In [None]:
from google.colab import drive
# Colab only: attach Google Drive so its files appear under /content/drive.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
import numpy as np
import torch
import pickle
import json
import gzip
import random
import os
import re
import sentencepiece as spm
from collections import defaultdict
import copy
from tqdm import tqdm
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
import logging
from pprint import pprint
import collections

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, Dataset, Sampler

from dataclasses import dataclass
from transformers.models.t5.modeling_t5 import (
    T5Stack, T5Block, T5LayerNorm, T5LayerSelfAttention, T5LayerFF, T5LayerCrossAttention,
    T5PreTrainedModel, T5ForConditionalGeneration
)
from transformers import T5Config, T5Tokenizer, PreTrainedTokenizer
from transformers.modeling_outputs import ModelOutput, BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput
from transformers.utils import logging
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

# Reproducibility: seed Python's `random` module and PyTorch's generator with 42.
random.seed(42)
torch.manual_seed(42)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

#### Params

In [None]:
class ModelParams:
    """Hyper-parameters for the model, tokenizer, optimisation and generation.

    Every value is a class attribute, so it can be read as e.g. ``ModelParams.lr``
    without creating an instance.
    """

    seed = 42                        # RNG seed value
    model_name = "t5-base"           # backbone checkpoint name
    tokenizer = "p5"                 # tokenizer variant identifier
    whole_word_embed = True          # NOTE(review): name suggests whole-word embeddings — confirm usage
    max_text_length = 256            # presumably the maximum input length in tokens — confirm
    do_lower_case = True
    batch_size = 4
    optimizer = torch.optim.Adam     # optimizer *class*, not an instance
    warmup_ratio = 0.05              # presumably the fraction of steps used for LR warmup
    weight_decay = 0.01
    clip_grad_norm = -1.0            # NOTE(review): presumably a non-positive value disables clipping
    gradient_accumulation_steps = 1
    lr = 1e-3
    adam_eps = 1e-6
    adam_beta1 = 0.9
    adam_beta2 = 0.999
    epoch = 1                        # presumably the number of training epochs
    dropout = 0.1
    # Loss names; presumably one running meter is kept per entry — confirm in the trainer.
    losses = ['rating_loss', 'sequential_loss', 'review_loss', 'metadata_loss', 'recommend_loss', 'total_loss']

    # --- inference ---
    gen_max_length = 64              # presumably the maximum generated sequence length

#### Data Template

In [None]:
# Pretraining tasks -- 5 prompt families:
#   1 rating, 2 sequential, 3 explanation, 4 review, 5 traditional.
# Every template is a dict with, in this key order:
#   source / target          -- prompt patterns with one "{}" per argv entry
#   task                     -- family name (also the key in `all_tasks`)
#   source_argc / argv       -- number and names of the fields filling `source`
#   target_argc / argv       -- number and names of the fields filling `target`
#   id                       -- "<family>-<index>" prompt id
all_tasks = {}

# =====================================================
# Task Subgroup 1 -- Rating -- 10 Prompts
# =====================================================

task_subgroup_1 = {}

# 1-1 | Which star rating will user {{user_id}} give item {{item_id}}? (1 being lowest and 5 being highest)
#     -> {{star_rating}}  [Accuracy]
template = {
    'source': "Which star rating will user_{} give item_{} ? ( 1 being lowest and 5 being highest )",
    'target': "{}",
    'task': "rating",
    'source_argc': 2,
    'source_argv': ['user_id', 'item_id'],
    'target_argc': 1,
    'target_argv': ['star_rating'],
    'id': "1-1",
}
task_subgroup_1["1-1"] = template

# 1-2 | How will user {{user_id}} rate this business: {{item_title}}? (1 being lowest and 5 being highest)
#     -> {{star_rating}}  [Accuracy]
template = {
    'source': "How will user_{} rate this business : {} ? ( 1 being lowest and 5 being highest )",
    'target': "{}",
    'task': "rating",
    'source_argc': 2,
    'source_argv': ['user_id', 'item_title'],
    'target_argc': 1,
    'target_argv': ['star_rating'],
    'id': "1-2",
}
task_subgroup_1["1-2"] = template

# 1-3 | Will user {{user_id}} give item {{item_id}} a {{star_rating}}-star rating? (1 being lowest and 5 being highest)
#     -> yes / no  [Accuracy]
template = {
    'source': "Will user_{} give item_{} a {}-star rating ? ( 1 being lowest and 5 being highest )",
    'target': "{}",
    'task': "rating",
    'source_argc': 3,
    'source_argv': ['user_id', 'item_id', 'star_rating'],
    'target_argc': 1,
    'target_argv': ['yes_no'],
    'id': "1-3",
}
task_subgroup_1["1-3"] = template

# 1-4 | Does user {{user_id}} like or dislike item {{item_id}}?
#     -> like (4, 5) / dislike (1, 2, 3)  [Accuracy]
template = {
    'source': "Does user_{} like or dislike item_{} ?",
    'target': "{}",
    'task': "rating",
    'source_argc': 2,
    'source_argv': ['user_id', 'item_id'],
    'target_argc': 1,
    'target_argv': ['like_dislike'],
    'id': "1-4",
}
task_subgroup_1["1-4"] = template

# 1-5 | Predict the user {{user_id}}'s preference on item {{item_id}} ({{item_title}}),
#     followed by the choices -1 .. -5 on separate lines
#     -> {{star_rating}} (answer_choices[star_rating - 1])  [Accuracy]
template = {
    'source': "Predict the user_{}’s preference on item_{} ( {} ) \n -1 \n -2 \n -3 \n -4 \n -5",
    'target': "{}",
    'task': "rating",
    'source_argc': 3,
    'source_argv': ['user_id', 'item_id', 'item_title'],
    'target_argc': 1,
    'target_argv': ['star_rating'],
    'id': "1-5",
}
task_subgroup_1["1-5"] = template

# 1-6 | What star rating do you think {{user_desc}} will give item {{item_id}}? (1 being lowest and 5 being highest)
#     -> {{star_rating}}  [Accuracy]
template = {
    'source': "What star rating do you think {} will give item_{} ? ( 1 being lowest and 5 being highest )",
    'target': "{}",
    'task': "rating",
    'source_argc': 2,
    'source_argv': ['user_desc', 'item_id'],
    'target_argc': 1,
    'target_argv': ['star_rating'],
    'id': "1-6",
}
task_subgroup_1["1-6"] = template

# 1-7 | How will {{user_desc}} rate this business: {{item_title}}? (1 being lowest and 5 being highest)
#     -> {{star_rating}}  [Accuracy]
template = {
    'source': "How will {} rate this business : {} ? ( 1 being lowest and 5 being highest )",
    'target': "{}",
    'task': "rating",
    'source_argc': 2,
    'source_argv': ['user_desc', 'item_title'],
    'target_argc': 1,
    'target_argv': ['star_rating'],
    'id': "1-7",
}
task_subgroup_1["1-7"] = template

# 1-8 | Will {{user_desc}} give a {{star_rating}}-star rating for {{item_title}}? (1 being lowest and 5 being highest)
#     -> yes / no  [Accuracy]
template = {
    'source': "Will {} give a {}-star rating for {} ? ( 1 being lowest and 5 being highest )",
    'target': "{}",
    'task': "rating",
    'source_argc': 3,
    'source_argv': ['user_desc', 'star_rating', 'item_title'],
    'target_argc': 1,
    'target_argv': ['yes_no'],
    'id': "1-8",
}
task_subgroup_1["1-8"] = template

# 1-9 | Does {{user_desc}} like or dislike {{item_title}}?
#     -> like (4, 5) / dislike (1, 2, 3)  [Accuracy]
template = {
    'source': "Does {} like or dislike {} ?",
    'target': "{}",
    'task': "rating",
    'source_argc': 2,
    'source_argv': ['user_desc', 'item_title'],
    'target_argc': 1,
    'target_argv': ['like_dislike'],
    'id': "1-9",
}
task_subgroup_1["1-9"] = template

# 1-10 | Predict {{user_desc}}'s preference towards {{item_title}} (1 being lowest and 5 being highest)
#      -> {{star_rating}} (answer_choices[star_rating - 1])  [Accuracy]
template = {
    'source': "Predict {} ’s preference towards {} ( 1 being lowest and 5 being highest )",
    'target': "{}",
    'task': "rating",
    'source_argc': 2,
    'source_argv': ['user_desc', 'item_title'],
    'target_argc': 1,
    'target_argv': ['star_rating'],
    'id': "1-10",
}
task_subgroup_1["1-10"] = template

all_tasks['rating'] = task_subgroup_1

# =====================================================
# Task Subgroup 2 -- Sequential -- 13 Prompts
# =====================================================
# Prompts 2-1..2-6 and 2-13 generate the next item, 2-7..2-10 pick it from
# a candidate list (metrics: HR, NDCG, MRR). 2-11 and 2-12 answer yes/no
# (metric: Accuracy).

task_subgroup_2 = {}

# 2-1 | Given the following visit history of user {{user_id}}: <history>
#     predict next possible business to be visited by the user?  -> {{item_id}}
template = {
    'source': "Given the following visit history of user_{} : \n {} \n predict next possible business to be visited by the user ?",
    'target': "{}",
    'task': "sequential",
    'source_argc': 2,
    'source_argv': ['user_id', 'visit_history'],
    'target_argc': 1,
    'target_argv': ['item_id'],
    'id': "2-1",
}
task_subgroup_2["2-1"] = template

# 2-2 | I find the visit history list of user {{user_id}}: <history>
#     I wonder what is the next item to recommend to the user. Can you help me decide?  -> {{item_id}}
template = {
    'source': "I find the visit history list of user_{} : \n {} \n I wonder what is the next item to recommend to the user . Can you help me decide ?",
    'target': "{}",
    'task': "sequential",
    'source_argc': 2,
    'source_argv': ['user_id', 'visit_history'],
    'target_argc': 1,
    'target_argv': ['item_id'],
    'id': "2-2",
}
task_subgroup_2["2-2"] = template

# 2-3 | Here is the visit history list of user {{user_id}}: <history>
#     try to recommend next item to the user  -> {{item_id}}
template = {
    'source': "Here is the visit history list of user_{} : \n {} \n try to recommend next item to the user",
    'target': "{}",
    'task': "sequential",
    'source_argc': 2,
    'source_argv': ['user_id', 'visit_history'],
    'target_argc': 1,
    'target_argv': ['item_id'],
    'id': "2-3",
}
task_subgroup_2["2-3"] = template

# 2-4 | Given the following visit history of {{user_desc}}: <history>
#     predict next possible business for the user  -> {{item_id}}
template = {
    'source': "Given the following visit history of {} : \n {} \n predict next possible business for the user",
    'target': "{}",
    'task': "sequential",
    'source_argc': 2,
    'source_argv': ['user_desc', 'visit_history'],
    'target_argc': 1,
    'target_argv': ['item_id'],
    'id': "2-4",
}
task_subgroup_2["2-4"] = template

# 2-5 | Based on the visit history of {{user_desc}}: <history>
#     Can you decide the next business likely to be visited by the user?  -> {{item_id}}
template = {
    'source': "Based on the visit history of {} : \n {} \n Can you decide the next business likely to be visited by the user ?",
    'target': "{}",
    'task': "sequential",
    'source_argc': 2,
    'source_argv': ['user_desc', 'visit_history'],
    'target_argc': 1,
    'target_argv': ['item_id'],
    'id': "2-5",
}
task_subgroup_2["2-5"] = template

# 2-6 | Here is the visit history of {{user_desc}}: <history>
#     What to recommend next for the user?  -> {{item_id}}
template = {
    'source': "Here is the visit history of {} : \n {} \n What to recommend next for the user ?",
    'target': "{}",
    'task': "sequential",
    'source_argc': 2,
    'source_argv': ['user_desc', 'visit_history'],
    'target_argc': 1,
    'target_argv': ['item_id'],
    'id': "2-6",
}
task_subgroup_2["2-6"] = template

# Extractive QA (choose from candidates)

# 2-7 | Here is the visit history of user {{user_id}}: <history>
#     Select the next possible business ... from the following candidates: <candidates>  -> {{item_id}}
template = {
    'source': "Here is the visit history of user_{} : \n {} \n Select the next possible business likely to be visited by the user from the following candidates : \n {}",
    'target': "{}",
    'task': "sequential",
    'source_argc': 3,
    'source_argv': ['user_id', 'visit_history', 'candidates'],
    'target_argc': 1,
    'target_argv': ['item_id'],
    'id': "2-7",
}
task_subgroup_2["2-7"] = template

# 2-8 | Given the following visit history of {{user_desc}}: <history>
#     What to recommend next for the user? Select one from the following items: <candidates>  -> {{item_id}}
template = {
    'source': "Given the following visit history of {} : \n {} \n What to recommend next for the user? Select one from the following items : \n {}",
    'target': "{}",
    'task': "sequential",
    'source_argc': 3,
    'source_argv': ['user_desc', 'visit_history', 'candidates'],
    'target_argc': 1,
    'target_argv': ['item_id'],
    'id': "2-8",
}
task_subgroup_2["2-8"] = template

# 2-9 | Based on the visit history of user {{user_id}}: <history>
#     Choose the next possible visited business from the following candidates: <candidates>  -> {{item_id}}
template = {
    'source': "Based on the visit history of user_{} : \n {} \n Choose the next possible visited business from the following candidates : \n {}",
    'target': "{}",
    'task': "sequential",
    'source_argc': 3,
    'source_argv': ['user_id', 'visit_history', 'candidates'],
    'target_argc': 1,
    'target_argv': ['item_id'],
    'id': "2-9",
}
task_subgroup_2["2-9"] = template

# 2-10 | I find the visit history list of {{user_desc}}: <history>
#      I wonder which is the next item ... Try to select one from the following candidates: <candidates>  -> {{item_id}}
template = {
    'source': "I find the visit history list of {} : \n {} \n I wonder which is the next item to recommend to the user . Try to select one from the following candidates : \n {}",
    'target': "{}",
    'task': "sequential",
    'source_argc': 3,
    'source_argv': ['user_desc', 'visit_history', 'candidates'],
    'target_argc': 1,
    'target_argv': ['item_id'],
    'id': "2-10",
}
task_subgroup_2["2-10"] = template

# 2-11 | User {{user_id}} has the following visit history: <history>
#      Does the user likely to visit {{item_id}} next?  -> yes / no  [Accuracy]
template = {
    'source': "user_{} has the following visit history : \n {} \n does the user likely to visit {} next ?",
    'target': "{}",
    'task': "sequential",
    'source_argc': 3,
    'source_argv': ['user_id', 'visit_history', 'item_id'],
    'target_argc': 1,
    'target_argv': ['yes_no'],
    'id': "2-11",
}
task_subgroup_2["2-11"] = template

# 2-12 | According to {{user_desc}}'s visit history list: <history>
#      Predict whether the user will visit {{item_id}} next?  -> yes / no  [Accuracy]
template = {
    'source': "According to {} 's visit history list : \n {} \n Predict whether the user will visit {} next ?",
    'target': "{}",
    'task': "sequential",
    'source_argc': 3,
    'source_argv': ['user_desc', 'visit_history', 'item_id'],
    'target_argc': 1,
    'target_argv': ['yes_no'],
    'id': "2-12",
}
task_subgroup_2["2-12"] = template

# 2-13 | According to the visit history of {{user_desc}}: <history>
#      Can you recommend the next possible business to the user?  -> {{item_id}}
template = {
    'source': "According to the visit history of {} : \n {} \n Can you recommend the next possible business to the user ?",
    'target': "{}",
    'task': "sequential",
    'source_argc': 2,
    'source_argv': ['user_desc', 'visit_history'],
    'target_argc': 1,
    'target_argv': ['item_id'],
    'id': "2-13",
}
task_subgroup_2["2-13"] = template

all_tasks['sequential'] = task_subgroup_2

# ====================================================
# Task Subgroup 3 -- Explanation -- 10 Prompts
# ====================================================
# Targets are generated explanations (3-5 and 3-6 prefix the star rating).
# Metrics: BLEU, ROUGE.

task_subgroup_3 = {}

# 3-1 | Generate an explanation for user {{user_id}} about this business: {{item_title}}  -> {{explanation}}
template = {
    'source': "Generate an explanation for user_{} about this business : {}",
    'target': "{}",
    'task': "explanation",
    'source_argc': 2,
    'source_argv': ['user_id', 'item_title'],
    'target_argc': 1,
    'target_argv': ['explanation'],
    'id': "3-1",
}
task_subgroup_3["3-1"] = template

# 3-2 | Help user {{user_id}} generate a {{star_rating}}-star explanation about this business:
#     {{item_title}}  -> {{explanation}}
template = {
    'source': "Help user_{} generate a {}-star explanation about this business : \n {}",
    'target': "{}",
    'task': "explanation",
    'source_argc': 3,
    'source_argv': ['user_id', 'star_rating', 'item_title'],
    'target_argc': 1,
    'target_argv': ['explanation'],
    'id': "3-2",
}
task_subgroup_3["3-2"] = template

# 3-3 | Generate an explanation for {{user_desc}} about this business: {{item_title}}  -> {{explanation}}
template = {
    'source': "Generate an explanation for {} about this business : {}",
    'target': "{}",
    'task': "explanation",
    'source_argc': 2,
    'source_argv': ['user_desc', 'item_title'],
    'target_argc': 1,
    'target_argv': ['explanation'],
    'id': "3-3",
}
task_subgroup_3["3-3"] = template

# 3-4 | Help {{user_desc}} generate a {{star_rating}}-star explanation for item {{item_id}}  -> {{explanation}}
template = {
    'source': "Help {} generate a {}-star explanation for item_{}",
    'target': "{}",
    'task': "explanation",
    'source_argc': 3,
    'source_argv': ['user_desc', 'star_rating', 'item_id'],
    'target_argc': 1,
    'target_argv': ['explanation'],
    'id': "3-4",
}
task_subgroup_3["3-4"] = template

# 3-5 | Predict the star rating, then use {{feature}} as feature word to generate
#     user {{user_id}}'s visit explanation for item {{item_id}}  -> {{star_rating}} , {{explanation}}
template = {
    'source': "Predict the star rating , then use {} as feature word to generate user_{} 's visit explanation for item_{}",
    'target': "{} , {}",
    'task': "explanation",
    'source_argc': 3,
    'source_argv': ['feature', 'user_id', 'item_id'],
    'target_argc': 2,
    'target_argv': ['star_rating', 'explanation'],
    'id': "3-5",
}
task_subgroup_3["3-5"] = template

# 3-6 | What score will {{user_desc}} rate item {{item_id}}? Then give an explanation for the
#     rating score. (1 being lowest and 5 being highest)  -> {{star_rating}} , {{explanation}}
template = {
    'source': "What score will {} rate item_{} ? Then give an explanation for the rating score . ( 1 being lowest and 5 being highest )",
    'target': "{} , {}",
    'task': "explanation",
    'source_argc': 2,
    'source_argv': ['user_desc', 'item_id'],
    'target_argc': 2,
    'target_argv': ['star_rating', 'explanation'],
    'id': "3-6",
}
task_subgroup_3["3-6"] = template

# 3-7 | Based on the feature word {{feature}}, generate an explanation for user {{user_id}}
#     about this business: {{item_title}}  -> {{explanation}}
template = {
    'source': "Based on the feature word {} , generate an explanation for user_{} about this business : {}",
    'target': "{}",
    'task': "explanation",
    'source_argc': 3,
    'source_argv': ['feature', 'user_id', 'item_title'],
    'target_argc': 1,
    'target_argv': ['explanation'],
    'id': "3-7",
}
task_subgroup_3["3-7"] = template

# 3-8 | Given the word {{feature}}, can you help generate an explanation for {{user_desc}}
#     about the business: {{item_title}}  -> {{explanation}}
template = {
    'source': "Given the word {} , can you help generate an explanation for {} about the business : \n {}",
    'target': "{}",
    'task': "explanation",
    'source_argc': 3,
    'source_argv': ['feature', 'user_desc', 'item_title'],
    'target_argc': 1,
    'target_argv': ['explanation'],
    'id': "3-8",
}
task_subgroup_3["3-8"] = template

# 3-9 | Using the word {{feature}}, write a {{star_rating}}-star explanation for user {{user_id}}
#     about item {{item_id}}  -> {{explanation}}
template = {
    'source': "Using the word {} , write a {}-star explanation for user_{} about item_{}",
    'target': "{}",
    'task': "explanation",
    'source_argc': 4,
    'source_argv': ['feature', 'star_rating', 'user_id', 'item_id'],
    'target_argc': 1,
    'target_argv': ['explanation'],
    'id': "3-9",
}
task_subgroup_3["3-9"] = template

# 3-10 | According to the feature word {{feature}}, generate a {{star_rating}}-star explanation
#      for {{user_desc}} about item {{item_id}}  -> {{explanation}}
template = {
    'source': "According to the feature word {} , generate a {}-star explanation for {} about item_{}",
    'target': "{}",
    'task': "explanation",
    'source_argc': 4,
    'source_argv': ['feature', 'star_rating', 'user_desc', 'item_id'],
    'target_argc': 1,
    'target_argv': ['explanation'],
    'id': "3-10",
}
task_subgroup_3["3-10"] = template

all_tasks['explanation'] = task_subgroup_3

# ====================================================
# Task Subgroup 4 -- Review -- 3 Prompts
# ====================================================
# Predict the star rating of a given review text. Metric: Accuracy.

task_subgroup_4 = {}

# 4-1 | Predict the associated rating score of the review written by user {{user_id}}
#     (1 being lowest and 5 being highest): {{review_body}}  -> {{star_rating}}
template = {
    'source': "Predict the associated rating score of the review written by user_{} ( 1 being lowest and 5 being highest ) : \n {}",
    'target': "{}",
    'task': "review",
    'source_argc': 2,
    'source_argv': ['user_id', 'review_body'],
    'target_argc': 1,
    'target_argv': ['star_rating'],
    'id': "4-1",
}
task_subgroup_4["4-1"] = template

# 4-2 | Given the following review written by user {{user_id}}: {{review_body}}
#     Can you predict the associated star rating (1 being lowest and 5 being highest)?  -> {{star_rating}}
template = {
    'source': "Given the following review written by user_{} : \n {} \n Can you predict the associated star rating ? ( 1 being lowest and 5 being highest )",
    'target': "{}",
    'task': "review",
    'source_argc': 2,
    'source_argv': ['user_id', 'review_body'],
    'target_argc': 1,
    'target_argv': ['star_rating'],
    'id': "4-2",
}
task_subgroup_4["4-2"] = template

# 4-3 | According to the following review written by {{user_desc}}: {{review_body}}
#     Predict the associated star rating (1 being lowest and 5 being highest)  -> {{star_rating}}
template = {
    'source': "According to the following review written by {} : \n {} \n Predict the associated star rating ( 1 being lowest and 5 being highest )",
    'target': "{}",
    'task': "review",
    'source_argc': 2,
    'source_argv': ['user_desc', 'review_body'],
    'target_argc': 1,
    'target_argv': ['star_rating'],
    'id': "4-3",
}
task_subgroup_4["4-3"] = template

all_tasks['review'] = task_subgroup_4

# =====================================================
# Task Subgroup 5 -- Traditional -- 8 Prompts
# =====================================================

task_subgroup_5 = {}

## Interaction prediction (binary yes/no) - 4 prompts. Metric: Accuracy (HR, NDCG, MRR).

# 5-1 | Will user {{user_id}} likely to interact with item {{item_id}}?  -> yes / no
template = {
    'source': "Will user_{} likely to interact with item_{} ?",
    'target': "{}",
    'task': "traditional",
    'source_argc': 2,
    'source_argv': ['user_id', 'item_id'],
    'target_argc': 1,
    'target_argv': ['yes_no'],
    'id': "5-1",
}
task_subgroup_5["5-1"] = template

# 5-2 | Shall we recommend item {{item_id}} to {{user_desc}}?  -> yes / no
template = {
    'source': "Shall we recommend item_{} to {} ?",
    'target': "{}",
    'task': "traditional",
    'source_argc': 2,
    'source_argv': ['item_id', 'user_desc'],
    'target_argc': 1,
    'target_argv': ['yes_no'],
    'id': "5-2",
}
task_subgroup_5["5-2"] = template

# 5-3 | For {{user_desc}}, do you think it is good to recommend {{item_title}}?  -> yes / no
template = {
    'source': "For {}, do you think it is good to recommend {} ?",
    'target': "{}",
    'task': "traditional",
    'source_argc': 2,
    'source_argv': ['user_desc', 'item_title'],
    'target_argc': 1,
    'target_argv': ['yes_no'],
    'id': "5-3",
}
task_subgroup_5["5-3"] = template

# 5-4 | I would like to recommend some items for user {{user_id}}. Is the following item a good
#     choice? {{item_title}}  -> yes / no
template = {
    'source': "I would like to recommend some items for user_{} . Is the following item a good choice ? \n {}",
    'target': "{}",
    'task': "traditional",
    'source_argc': 2,
    'source_argv': ['user_id', 'item_title'],
    'target_argc': 1,
    'target_argv': ['yes_no'],
    'id': "5-4",
}
task_subgroup_5["5-4"] = template

## Extractive QA (pick from candidates) - 4 prompts. Metrics: HR, NDCG, MRR.

# 5-5 | Which item of the following to recommend for {{user_desc}}? <candidates>  -> ground-truth item ids
template = {
    'source': "Which item of the following to recommend for {} ? \n {}",
    'target': "{}",
    'task': "traditional",
    'source_argc': 2,
    'source_argv': ['user_desc', 'candidates'],
    'target_argc': 1,
    'target_argv': ['groundtruth_item_ids'],
    'id': "5-5",
}
task_subgroup_5["5-5"] = template

# 5-6 | Choose the best item from the candidates to recommend for {{user_desc}}? <candidates>
#     -> ground-truth item ids
template = {
    'source': "Choose the best item from the candidates to recommend for {} ? \n {}",
    'target': "{}",
    'task': "traditional",
    'source_argc': 2,
    'source_argv': ['user_desc', 'candidates'],
    'target_argc': 1,
    'target_argv': ['groundtruth_item_ids'],
    'id': "5-6",
}
task_subgroup_5["5-6"] = template

# 5-7 | Pick the most suitable item from the following list and recommend to user {{user_id}}:
#     <candidates>  -> ground-truth item ids
template = {
    'source': "Pick the most suitable item from the following list and recommend to user_{} : \n {}",
    'target': "{}",
    'task': "traditional",
    'source_argc': 2,
    'source_argv': ['user_id', 'candidates'],
    'target_argc': 1,
    'target_argv': ['groundtruth_item_ids'],
    'id': "5-7",
}
task_subgroup_5["5-7"] = template

# 5-8 | We want to make recommendation for user {{user_id}}. Select the best item from these
#     candidates: <candidates>  -> ground-truth item ids
template = {
    'source': "We want to make recommendation for user_{} .  Select the best item from these candidates : \n {}",
    'target': "{}",
    'task': "traditional",
    'source_argc': 2,
    'source_argv': ['user_id', 'candidates'],
    'target_argc': 1,
    'target_argv': ['groundtruth_item_ids'],
    'id': "5-8",
}
task_subgroup_5["5-8"] = template

all_tasks['traditional'] = task_subgroup_5

# Public alias for the full template registry.
task_templates = all_tasks

#### Utils

In [None]:
class LossMeter(object):
    """Running average over the most recent ``maxlen`` recorded values.

    Values are kept in a bounded deque, so once ``maxlen`` values have been
    recorded each new value evicts the oldest one.
    """

    def __init__(self, maxlen=100):
        """Create an empty meter that keeps at most ``maxlen`` values."""
        self.vals = collections.deque([], maxlen=maxlen)

    def __len__(self):
        """Return how many values are currently held (at most ``maxlen``)."""
        return len(self.vals)

    def update(self, new_val):
        """Record ``new_val``, evicting the oldest value when the window is full."""
        self.vals.append(new_val)

    @property
    def val(self):
        """Mean of the held values.

        Raises:
            ZeroDivisionError: if no value has been recorded yet; check
                ``len(meter)`` first.
        """
        return sum(self.vals) / len(self.vals)

    def __repr__(self):
        # repr()/str() must not raise. Previously an empty meter hit
        # ZeroDivisionError here, e.g. when printed before its first update.
        if not self.vals:
            return "nan"
        return str(self.val)


def count_parameters(model):
    """Return the total number of scalar elements in ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted; frozen ones are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def load_state_dict(state_dict_path, loc='cpu'):
    """Load a checkpoint and strip the ``module.`` prefix left by multi-GPU wrappers.

    Args:
        state_dict_path: path passed to ``torch.load``.
        loc: ``map_location`` for ``torch.load`` (default ``'cpu'``).

    Returns:
        The loaded mapping, modified in place: every ``module.<name>`` key is
        re-inserted as ``<name>``. Renamed keys move to the end of the iteration order.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    state_dict = torch.load(state_dict_path, map_location=loc)
    prefix = "module."
    # Take the list of keys to rename up front, because the dict is mutated in the
    # loop. Renaming in place also keeps attributes on the returned object
    # (e.g. a state_dict's ``_metadata``).
    wrapped_keys = [k for k in state_dict.keys() if k.startswith(prefix)]
    for old_key in wrapped_keys:
        state_dict[old_key[len(prefix):]] = state_dict.pop(old_key)
    return state_dict


def set_global_logging_level(level=logging.ERROR, prefices=("",)):
    """
    Override logging levels of different modules based on their name as a prefix.
    It needs to be invoked after the modules have been loaded so that their loggers have been initialized.

    Args:
        - level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR
        - prefices: sequence of one or more str prefices to match (e.g. ["transformers", "torch"]). Optional.
          Default is `("",)` to match all active loggers.
          The match is a case-sensitive `module_name.startswith(prefix)`; regex
          metacharacters in a prefix (such as ".") are matched literally.
    """
    # This notebook rebinds the global name ``logging`` to
    # ``transformers.utils.logging`` (``from transformers.utils import logging``).
    # That module provides the level constants, so the default above still works,
    # but it does not provide ``root``/``getLogger``. Use the stdlib module for
    # the logger registry.
    import logging as std_logging

    # Escape each prefix so the match really is startswith(). An empty sequence
    # still compiles to "^(?:)", which matches every name.
    prefix_re = re.compile("^(?:" + "|".join(re.escape(p) for p in prefices) + ")")
    # Iterate over a copy of the names: getLogger() can replace placeholder
    # entries in loggerDict while we loop.
    for name in list(std_logging.root.manager.loggerDict):
        if prefix_re.match(name):
            std_logging.getLogger(name).setLevel(level)

#### Tokenizer

In [None]:
class CustomTokenizer(T5Tokenizer):
    """T5 SentencePiece tokenizer extended with sentinel, user and item tokens.

    Id layout implied by ``vocab_size``, ``_convert_token_to_id`` and
    ``_convert_id_to_token`` (P = number of SentencePiece pieces, E = extra_ids,
    U = user_extra_ids, I = item_extra_ids):

        [0, P)                        SentencePiece pieces
        [P, P + E)                    <extra_id_k>, id = P + E - 1 - k
        [P + E, P + E + U)            <user_id_k>,  id = P + E + U - 1 - k
        [P + E + U, P + E + U + I)    <item_id_k>,  id = P + E + U + I - 1 - k

    i.e. within each group, index k = 0 receives the highest id.
    """

    def __init__(self, vocab_file, eos_token="</s>", unk_token="<unk>", pad_token="<pad>", extra_ids=100, user_extra_ids=0, item_extra_ids=0, **kwargs):
        """Create the tokenizer.

        Args:
            vocab_file: path to the SentencePiece model file.
            eos_token, unk_token, pad_token: forwarded to ``T5Tokenizer.__init__``.
            extra_ids: number of ``<extra_id_k>`` sentinel tokens.
            user_extra_ids: number of ``<user_id_k>`` tokens (0 disables them).
            item_extra_ids: number of ``<item_id_k>`` tokens (0 disables them).
            **kwargs: forwarded to ``T5Tokenizer.__init__``.
        """
        # These attributes are assigned before super().__init__(), presumably because the
        # base constructor may already call vocab_size / get_vocab, which read them.
        # NOTE(review): confirm against the installed transformers version.
        self.vocab_file = vocab_file
        self._extra_ids = extra_ids
        self._user_extra_ids = user_extra_ids
        self._item_extra_ids = item_extra_ids
        # Full list of added tokens: sentinels, then user ids, then item ids.
        self.additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
        if user_extra_ids>0:
            self.additional_special_tokens.extend([f"<user_id_{i}>" for i in range(user_extra_ids)])
        if item_extra_ids>0:
            self.additional_special_tokens.extend([f"<item_id_{i}>" for i in range(item_extra_ids)])

        # NOTE(review): `extra_ids` is not forwarded here, so T5Tokenizer presumably applies
        # its own default for sentinel tokens; the complete token list is registered by
        # add_special_tokens() right after — confirm the two agree.
        super().__init__(vocab_file, eos_token=eos_token, unk_token=unk_token, pad_token=pad_token, **kwargs)
        self.add_special_tokens({"additional_special_tokens": self.additional_special_tokens})
        # (Re)load the SentencePiece processor directly from vocab_file.
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)

    @property
    def vocab_size(self):
        """SentencePiece pieces plus all sentinel, user and item tokens."""
        return self.sp_model.get_piece_size() + self._extra_ids + self._user_extra_ids + self._item_extra_ids

    def get_vocab(self):
        """Return a token -> id mapping for every id below vocab_size, overlaid with added tokens."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _convert_token_to_id(self, token):
        """Map a token string to its id, following the layout in the class docstring."""
        if token.startswith("<extra_id_"):
            match = re.match(r"<extra_id_(\d+)>", token)
            num = int(match.group(1))
            return self.vocab_size - num - 1 - self._user_extra_ids - self._item_extra_ids
        # NOTE(review): the user/item branches test `in` rather than startswith(). Because
        # re.match anchors at the start, a token that merely contains "<user_id_" (or
        # "<item_id_") elsewhere yields match=None and match.group raises AttributeError.
        elif "<user_id_" in token:
            match = re.match(r"<user_id_(\d+)>", token)
            num = int(match.group(1))
            return self.vocab_size - num - 1 - self._item_extra_ids
        elif "<item_id_" in token:
            match = re.match(r"<item_id_(\d+)>", token)
            num = int(match.group(1))
            return self.vocab_size - num - 1
        # Anything else is looked up in the SentencePiece model.
        return self.sp_model.piece_to_id(token)


    def _convert_id_to_token(self, index):
        """Inverse of _convert_token_to_id for ids in [0, vocab_size)."""
        if index < self.sp_model.get_piece_size():
            token = self.sp_model.IdToPiece(index)
        else:
            # NOTE(review): ids >= vocab_size are not rejected; they fall into the item
            # branch and produce a negative index such as "<item_id_-1>".
            if index > self.sp_model.get_piece_size() + self._extra_ids + self._user_extra_ids - 1:
                # Item block: [P + E + U, vocab_size).
                token = "<item_id_{}>".format(self.vocab_size - 1 - index)
            elif index > self.sp_model.get_piece_size() + self._extra_ids - 1:
                # User block: [P + E, P + E + U).
                token = "<user_id_{}>".format(self.vocab_size - self._item_extra_ids - 1 - index)
            else:
                # Sentinel block: [P, P + E).
                token = "<extra_id_{}>".format(self.vocab_size - self._user_extra_ids - self._item_extra_ids - 1 - index)
        return token

#### Get Data

In [None]:
def load_json(file_path):
    """Load and return the JSON document stored at *file_path*, decoded as UTF-8."""
    # JSON exchanged between systems must be UTF-8 (RFC 8259 §8.1); be explicit instead
    # of relying on the platform's locale-dependent default text encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)
def load_pickle(filename):
    """Unpickle and return the object stored in *filename*.

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """
    with open(filename, "rb") as handle:
        obj = pickle.load(handle)
    return obj
def ReadLineFromFile(path):
    """Return the lines of the UTF-8 text file at *path*, each without its trailing newline."""
    # Explicit encoding: the platform default is locale-dependent.
    with open(path, 'r', encoding='utf-8') as fd:
        return [line.rstrip('\n') for line in fd]
def parse(path):
    """Yield one record per line of a gzip-compressed file of Python literals.

    Each line is decoded as UTF-8 and parsed with ``ast.literal_eval``. This
    accepts dict/list/str/number/True/False/None literals but, unlike ``eval``,
    never executes code found in the data file. The file is closed when the
    generator is exhausted or discarded.

    Raises:
        ValueError / SyntaxError: if a line is not a valid Python literal.
    """
    import ast

    with gzip.open(path, 'r') as g:
        for line in g:
            yield ast.literal_eval(line.decode('utf-8'))


class P5YelpDataset(Dataset):
    def __init__(self, all_tasks, task_list, tokenizer, args, sample_numbers, mode='train', rating_augment=False, sample_type='random',
                 data_source_folder="drive/MyDrive/P5 Recommender/data/"):
        """Load the Yelp P5 data files and build the flat sample index.

        Args:
            all_tasks: prompt-template registry, indexed as ``all_tasks[family][template_id]``.
            task_list: mapping of task family -> list of template ids to draw from.
            tokenizer: stored as ``self.tokenizer`` (``collate_fn`` reads its ``pad_token_id``).
            args: stored as ``self.args``.
            sample_numbers: per-family repetition counts consumed by ``compute_datum_info``
                (a scalar for rating/explanation/review, an indexable of counts for
                sequential/traditional).
            mode: split to load from the ``*_splits.pkl`` files ('train', 'val' or 'test').
            rating_augment: if True, ratings come from ``rating_splits_augmented.pkl``;
                otherwise the review split is reused for ratings.
            sample_type: 'random' for uniform negative sampling; any other value samples
                negatives in proportion to item frequency (``self.probability``).
            data_source_folder: directory holding the data files (with or without a
                trailing slash). Defaults to the original hard-coded Drive location.
        """
        self.all_tasks = all_tasks
        self.task_list = task_list
        self.tokenizer = tokenizer
        self.args = args
        self.sample_numbers = sample_numbers
        self.rating_augment = rating_augment
        self.sample_type = sample_type
        self.mode = mode
        self.data_source_folder = data_source_folder

        def _path(filename):
            # os.path.join tolerates a folder given with or without a trailing slash.
            return os.path.join(self.data_source_folder, filename)

        self.review_data = load_pickle(_path("review_splits.pkl"))[self.mode]
        self.exp_data = load_pickle(_path("exp_splits.pkl"))[self.mode]
        if self.rating_augment:
            self.rating_data = load_pickle(_path("rating_splits_augmented.pkl"))[self.mode]
        else:
            self.rating_data = self.review_data

        # One line per user: "<user_id> <item_id_1> <item_id_2> ..." — the restaurants the
        # user has visited.
        self.sequential_data = ReadLineFromFile(_path("sequential_data.txt"))
        item_count = defaultdict(int)
        user_items = {}
        for line in self.sequential_data:
            user, items = line.strip().split(' ', 1)
            items = [int(item) for item in items.split(' ')]
            user_items[user] = items
            for item in items:
                item_count[item] += 1
        self.all_item = list(item_count.keys())
        count = list(item_count.values())
        sum_value = sum(count)
        # Probability of picking each item (same order as self.all_item), proportional to
        # how often it appears in the user sequences.
        self.probability = [value / sum_value for value in count]
        self.user_items = user_items

        if self.mode == 'test':
            self.negative_samples = ReadLineFromFile(_path("negative_samples.txt"))

        # Id mappings: user2id, id2user, item2id, id2item
        # (original notes: 30431 users, 20033 items).
        datamaps = load_json(_path("datamaps.json"))
        self.user2id = datamaps['user2id']
        self.item2id = datamaps['item2id']
        self.user_list = list(datamaps['user2id'].keys())
        self.item_list = list(datamaps['item2id'].keys())
        self.id2item = datamaps['id2item']

        # user_id -> user name.
        self.user_id2name = load_pickle(_path("user_id2name.pkl"))
        # Per-restaurant metadata records (name, attributes, ...), each with a 'business_id'.
        self.meta_data = load_pickle(_path("meta_data.pkl"))
        # Per-user profile records, each with a 'user_id'.
        self.user_data = load_pickle(_path("user_data.pkl"))

        # business_id -> index into self.meta_data (a later duplicate id wins).
        self.meta_dict = {meta_item['business_id']: i for i, meta_item in enumerate(self.meta_data)}
        print("Total number of businesses: stored in self.meta_dict", len(self.meta_dict))
        # user_id -> index into self.user_data (a later duplicate id wins).
        self.user_meta_dict = {user_meta_item['user_id']: j for j, user_meta_item in enumerate(self.user_data)}
        print("Total number of users: stored in self.user_meta_dict", len(self.user_meta_dict))

        print('compute_datum_info')
        self.total_length = 0
        self.datum_info = []
        self.compute_datum_info()
        print("Init call finished")

    def compute_datum_info(self):
        curr = 0
        for key in list(self.task_list.keys()):
            if key == 'rating':  #
                self.total_length += len(self.rating_data) * self.sample_numbers[key]
                for i in range(self.total_length - curr):
                    self.datum_info.append((i + curr, key, i // self.sample_numbers[key]))
                curr = self.total_length
            elif key == 'sequential':
                if sum([0 < int(ind.split('-')[1]) <= 6 or int(ind.split('-')[1]) == 13 for ind in self.task_list[key]]):
                    self.total_length += len(self.sequential_data) * self.sample_numbers[key][0]
                    for i in range(self.total_length - curr):
                        self.datum_info.append((i + curr, key, i // self.sample_numbers[key][0]))
                    curr = self.total_length
                if sum([6 < int(ind.split('-')[1]) <= 10 for ind in self.task_list[key]]):
                    self.total_length += len(self.sequential_data) * self.sample_numbers[key][1]
                    for i in range(self.total_length - curr):
                        self.datum_info.append((i + curr, key, i // self.sample_numbers[key][1]))
                    curr = self.total_length
                if sum([10 < int(ind.split('-')[1]) <= 12 for ind in self.task_list[key]]):
                    self.total_length += len(self.sequential_data) * self.sample_numbers[key][2]
                    for i in range(self.total_length - curr):
                        self.datum_info.append((i + curr, key, i // self.sample_numbers[key][2]))
                    curr = self.total_length
            elif key == 'explanation':
                self.total_length += len(self.exp_data) * self.sample_numbers[key]
                for i in range(self.total_length - curr):
                    self.datum_info.append((i + curr, key, i // self.sample_numbers[key]))
                curr = self.total_length
            elif key == 'review':
                self.total_length += len(self.review_data) * self.sample_numbers[key]
                for i in range(self.total_length - curr):
                    self.datum_info.append((i + curr, key, i // self.sample_numbers[key]))
                curr = self.total_length
            elif key == 'traditional':
                if sum([0 < int(ind.split('-')[1]) <= 4 for ind in self.task_list[key]]):
                    self.total_length += len(self.user2id) * self.sample_numbers[key][0]
                    for i in range(self.total_length - curr):
                        self.datum_info.append((i + curr, key, i // self.sample_numbers[key][0]))
                    curr = self.total_length
                if sum([4 < int(ind.split('-')[1]) <= 8 for ind in self.task_list[key]]):
                    self.total_length += len(self.user2id) * self.sample_numbers[key][1]
                    for i in range(self.total_length - curr):
                        self.datum_info.append((i + curr, key, i // self.sample_numbers[key][1]))
                    curr = self.total_length
            else:
                raise NotImplementedError

    def gaussian_sampling(self, datum):
        if self.mode == 'train':
            if int(datum['overall']) == 1:
                sampled_rating = round(torch.normal(mean=torch.tensor((1.0+1.4)/2), std=torch.tensor((1.4-1.0)/4)).item(), 1)
            elif int(datum['overall']) == 2:
                sampled_rating = round(torch.normal(mean=torch.tensor((1.5+2.4)/2), std=torch.tensor((2.4-1.5)/4)).item(), 1)
            elif int(datum['overall']) == 3:
                sampled_rating = round(torch.normal(mean=torch.tensor((2.5+3.4)/2), std=torch.tensor((3.4-2.5)/4)).item(), 1)
            elif int(datum['overall']) == 4:
                sampled_rating = round(torch.normal(mean=torch.tensor((3.5+4.4)/2), std=torch.tensor((4.4-3.5)/4)).item(), 1)
            else:
                sampled_rating = round(torch.normal(mean=torch.tensor((4.5+5.0)/2), std=torch.tensor((5.0-4.5)/4)).item(), 1)
            if sampled_rating > 5.0:
                sampled_rating = 5.0
            if sampled_rating < 1.0:
                sampled_rating = 1.0
            return str(sampled_rating)
        else:
            return int(datum['overall'])

    def collate_fn(self, batch):
        batch_entry = {}
        B = len(batch)
        args = self.args
        S_W_L = max(entry['input_length'] for entry in batch)
        T_W_L = max(entry['target_length'] for entry in batch)

        input_ids = torch.ones(B, S_W_L, dtype=torch.long) * self.tokenizer.pad_token_id
        whole_word_ids = torch.ones(B, S_W_L, dtype=torch.long) * self.tokenizer.pad_token_id
        target_ids = torch.ones(B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id
        loss_weights = torch.ones(B, dtype=torch.float)

        tasks = []
        source_text = []
        tokenized_text = []
        target_text = []

        for i, entry in enumerate(batch):
            input_ids[i, :entry['input_length']] = entry['input_ids']
            whole_word_ids[i, :entry['input_length']] = entry['whole_word_ids']
            target_ids[i, :entry['target_length']] = entry['target_ids']

            if 'task' in entry:
                tasks.append(entry['task'])

            if 'source_text' in entry:
                source_text.append(entry['source_text'])

            if 'tokenized_text' in entry:
                tokenized_text.append(entry['tokenized_text'])

            if 'target_text' in entry:
                target_text.append(entry['target_text'])

            if 'loss_weight' in entry:
                loss_weights[i] = entry['loss_weight']

        word_mask = target_ids != self.tokenizer.pad_token_id
        target_ids[~word_mask] = -100
        batch_entry['task'] = tasks
        batch_entry['source_text'] = source_text
        batch_entry['target_text'] = target_text
        batch_entry['input_ids'] = input_ids
        batch_entry['whole_word_ids'] = whole_word_ids
        batch_entry['target_ids'] = target_ids
        batch_entry['loss_weights'] = loss_weights

        return batch_entry

    def calculate_whole_word_ids(self, tokenized_text, input_ids):
        """Assign a whole-word index to every sub-word token.

        A token starting with ``'▁'`` (the SentencePiece word-start marker) opens a new
        word, numbered from 1. Other tokens inherit the current word's index, so tokens
        before the first marker get 0. The result is cut to ``len(input_ids) - 1`` entries
        and a trailing 0 is appended for the ``</s>`` position.

        Args:
            tokenized_text: sub-word token strings.
            input_ids: encoded ids; only its length is used (presumably it ends with
                the ``</s>`` id — confirm at call sites).

        Returns:
            list[int]: whole-word ids aligned with ``input_ids``.
        """
        whole_word_ids = []
        word_index = 0
        for token in tokenized_text:
            if token.startswith('▁'):
                word_index += 1
            whole_word_ids.append(word_index)
        # Only the prefix aligned with input_ids is kept. No element is read by position,
        # so empty text (no tokens, input_ids == [</s>]) simply yields [0].
        return whole_word_ids[:len(input_ids) - 1] + [0]  # [0] for </s>

    def __len__(self):
        return self.total_length

    def __getitem__(self, idx):

        out_dict = {}
        out_dict['args'] = self.args

        loss_weight = 1.0

        datum_info_idx = self.datum_info[idx]
        task_name = datum_info_idx[1]
        datum_idx = datum_info_idx[2]

        if task_name == 'rating':
            rating_datum = self.rating_data[datum_idx]
            task_candidates = self.task_list[task_name]
            task_idx = random.randint(0, len(task_candidates) - 1)  # random choose the task index for task_candidates
            task_template = self.all_tasks['rating'][task_candidates[task_idx]]
            assert task_template['task'] == 'rating'

            if task_template['id'] == '1-1':
                source_text = task_template['source'].format(self.user2id[rating_datum['reviewerID']],
                                                             self.item2id[rating_datum['asin']])
                target_text = task_template['target'].format(self.gaussian_sampling(rating_datum))
            elif task_template['id'] == '1-2':
                if 'name' in self.meta_data[self.meta_dict[rating_datum['asin']]]:
                    title = self.meta_data[self.meta_dict[rating_datum['asin']]]['name']
                else:
                    title = 'unknown name'
                source_text = task_template['source'].format(self.user2id[rating_datum['reviewerID']], title)
                target_text = task_template['target'].format(self.gaussian_sampling(rating_datum))
            elif task_template['id'] == '1-3':
                rand_prob = random.random()
                if rand_prob > 0.5:
                    source_text = task_template['source'].format(self.user2id[rating_datum['reviewerID']],
                                                                 self.item2id[rating_datum['asin']],
                                                                 int(rating_datum['overall']))
                    target_text = task_template['target'].format('yes')
                else:
                    overall_candidates = [_ for _ in range(0 + 1, 5 + 1) if _ != int(rating_datum['overall'])]
                    overall_idx = random.randint(0, len(overall_candidates) - 1)  # random choose the overall index for overall_candidates
                    source_text = task_template['source'].format(self.user2id[rating_datum['reviewerID']],
                                                                 self.item2id[rating_datum['asin']],
                                                                 overall_candidates[overall_idx])
                    target_text = task_template['target'].format('no')
            elif task_template['id'] == '1-4':
                source_text = task_template['source'].format(self.user2id[rating_datum['reviewerID']],
                                                             self.item2id[rating_datum['asin']])
                if int(rating_datum['overall']) >= 4:
                    target_text = task_template['target'].format('like')
                else:
                    target_text = task_template['target'].format('dislike')
            elif task_template['id'] == '1-5':
                if 'name' in self.meta_data[self.meta_dict[rating_datum['asin']]]:
                    title = self.meta_data[self.meta_dict[rating_datum['asin']]]['name']
                else:
                    title = 'unknown name'
                source_text = task_template['source'].format(self.user2id[rating_datum['reviewerID']],
                                                             self.item2id[rating_datum['asin']], title)
                target_text = task_template['target'].format(self.gaussian_sampling(rating_datum))
            elif task_template['id'] == '1-6':
                if 'name' in self.user_data[self.user_meta_dict[rating_datum['reviewerID']]]:
                    user_desc = self.user_data[self.user_meta_dict[rating_datum['reviewerID']]]['name']
                else:
                    user_desc = rating_datum['reviewerID']
                source_text = task_template['source'].format(user_desc, self.item2id[rating_datum['asin']])
                target_text = task_template['target'].format(self.gaussian_sampling(rating_datum))
            elif task_template['id'] == '1-7':
                if 'name' in self.user_data[self.user_meta_dict[rating_datum['reviewerID']]]:
                    user_desc = self.user_data[self.user_meta_dict[rating_datum['reviewerID']]]['name']
                else:
                    user_desc = rating_datum['reviewerID']
                if 'name' in self.meta_data[self.meta_dict[rating_datum['asin']]]:
                    title = self.meta_data[self.meta_dict[rating_datum['asin']]]['name']
                else:
                    title = 'unknown name'
                source_text = task_template['source'].format(user_desc, title)
                target_text = task_template['target'].format(self.gaussian_sampling(rating_datum))
            elif task_template['id'] == '1-8':
                rand_prob = random.random()
                if 'name' in self.user_data[self.user_meta_dict[rating_datum['reviewerID']]]:
                    user_desc = self.user_data[self.user_meta_dict[rating_datum['reviewerID']]]['name']
                else:
                    user_desc = rating_datum['reviewerID']
                if 'name' in self.meta_data[self.meta_dict[rating_datum['asin']]]:
                    title = self.meta_data[self.meta_dict[rating_datum['asin']]]['name']
                else:
                    title = 'unknown name'
                if rand_prob > 0.5:
                    source_text = task_template['source'].format(user_desc, int(rating_datum['overall']), title)
                    target_text = task_template['target'].format('yes')
                else:
                    overall_candidates = [_ for _ in range(0 + 1, 5 + 1) if _ != int(rating_datum['overall'])]
                    overall_idx = random.randint(0,
                                                 len(overall_candidates) - 1)  # random choose the overall index for overall_candidates
                    source_text = task_template['source'].format(user_desc, overall_candidates[overall_idx], title)
                    target_text = task_template['target'].format('no')
            elif task_template['id'] == '1-9':
                if 'name' in self.user_data[self.user_meta_dict[rating_datum['reviewerID']]]:
                    user_desc = self.user_data[self.user_meta_dict[rating_datum['reviewerID']]]['name']
                else:
                    user_desc = rating_datum['reviewerID']
                if 'name' in self.meta_data[self.meta_dict[rating_datum['asin']]]:
                    title = self.meta_data[self.meta_dict[rating_datum['asin']]]['name']
                else:
                    title = 'unknown name'
                source_text = task_template['source'].format(user_desc, title)
                if int(rating_datum['overall']) >= 4:
                    target_text = task_template['target'].format('like')
                else:
                    target_text = task_template['target'].format('dislike')
            elif task_template['id'] == '1-10':
                if 'name' in self.user_data[self.user_meta_dict[rating_datum['reviewerID']]]:
                    user_desc = self.user_data[self.user_meta_dict[rating_datum['reviewerID']]]['name']
                else:
                    user_desc = rating_datum['reviewerID']
                if 'name' in self.meta_data[self.meta_dict[rating_datum['asin']]]:
                    title = self.meta_data[self.meta_dict[rating_datum['asin']]]['name']
                else:
                    title = 'unknown name'
                source_text = task_template['source'].format(user_desc, title)
                target_text = task_template['target'].format(self.gaussian_sampling(rating_datum))
            else:
                raise NotImplementedError

        elif task_name == 'sequential':
            sequential_datum = self.sequential_data[datum_idx]
            sequence = sequential_datum.split()
            user_id = sequence[0]
            user_desc = self.user_id2name[user_id]
            history_limit = 20
            if self.mode == 'train':
                end_candidates = [_ for _ in range(max(2, len(sequence) - 6), len(sequence) - 3)]
                end_index = random.randint(0, len(end_candidates) - 1)
                end_pos = end_candidates[end_index]
                start_candidates = [_ for _ in range(1, min(4, end_pos))]
                start_index = random.randint(0, len(start_candidates) - 1)
                start_pos = start_candidates[start_index]
                purchase_history = sequence[start_pos:end_pos + 1]
                target_item = sequence[end_pos + 1]
            elif self.mode == 'val':
                purchase_history = sequence[1:-2]
                target_item = sequence[-2]
            elif self.mode == 'test':
                purchase_history = sequence[1:-1]
                target_item = sequence[-1]
            else:
                raise NotImplementedError
            if len(purchase_history) > history_limit:
                purchase_history = purchase_history[-history_limit:]

            task_candidates = self.task_list[task_name]
            task_idx = random.randint(0, len(task_candidates) - 1)  # random choose the task index for task_candidates
            task_template = self.all_tasks['sequential'][task_candidates[task_idx]]
            assert task_template['task'] == 'sequential'

            if task_template['id'] == '2-1':
                rand_prob = random.random()
                if rand_prob > 0.5:
                    source_text = task_template['source'].format(user_id, ' , '.join(purchase_history))
                else:
                    source_text = task_template['source'].format(user_id, ' -> '.join(purchase_history))
                target_text = task_template['target'].format(target_item)
            elif task_template['id'] == '2-2':
                rand_prob = random.random()
                if rand_prob > 0.5:
                    source_text = task_template['source'].format(user_id, ' , '.join(purchase_history))
                else:
                    source_text = task_template['source'].format(user_id, ' -> '.join(purchase_history))
                target_text = task_template['target'].format(target_item)
            elif task_template['id'] == '2-3':
                rand_prob = random.random()
                if rand_prob > 0.5:
                    source_text = task_template['source'].format(user_id, ' , '.join(purchase_history))
                else:
                    source_text = task_template['source'].format(user_id, ' -> '.join(purchase_history))
                target_text = task_template['target'].format(target_item)
            elif task_template['id'] == '2-4':
                rand_prob = random.random()
                if rand_prob > 0.5:
                    source_text = task_template['source'].format(user_desc, ' , '.join(purchase_history))
                else:
                    source_text = task_template['source'].format(user_desc, ' -> '.join(purchase_history))
                target_text = task_template['target'].format(target_item)
            elif task_template['id'] == '2-5':
                rand_prob = random.random()
                if rand_prob > 0.5:
                    source_text = task_template['source'].format(user_desc, ' , '.join(purchase_history))
                else:
                    source_text = task_template['source'].format(user_desc, ' -> '.join(purchase_history))
                target_text = task_template['target'].format(target_item)
            elif task_template['id'] == '2-6':
                rand_prob = random.random()
                if rand_prob > 0.5:
                    source_text = task_template['source'].format(user_desc, ' , '.join(purchase_history))
                else:
                    source_text = task_template['source'].format(user_desc, ' -> '.join(purchase_history))
                target_text = task_template['target'].format(target_item)
            elif task_template['id'] == '2-7' or task_template['id'] == '2-9':
                if self.mode in ['train', 'val']:
                    user_seq = self.user_items[user_id]
                    candidate_samples = []
                    candidate_num = random.randint(79, 99)
                    while len(candidate_samples) < candidate_num:
                        if self.sample_type == 'random':
                            sample_ids = np.random.choice(self.all_item, candidate_num, replace=False)
                        else:
                            sample_ids = np.random.choice(self.all_item, candidate_num, replace=False,
                                                          p=self.probability)
                        sample_ids = [str(item) for item in sample_ids if item not in user_seq and item not in candidate_samples]
                        candidate_samples.extend(sample_ids)
                    candidate_samples = candidate_samples[:candidate_num]
                elif self.mode == 'test':
                    assert user_id == self.negative_samples[int(user_id) - 1].split(' ', 1)[0]
                    candidate_samples = self.negative_samples[int(user_id) - 1].split(' ', 1)[1].split(' ')
                else:
                    raise NotImplementedError
                candidate_samples.extend([target_item])
                random.shuffle(candidate_samples)
                rand_prob = random.random()
                if rand_prob > 0.5:
                    source_text = task_template['source'].format(user_id, ' , '.join(purchase_history),
                                                                 ' , '.join(candidate_samples))
                else:
                    source_text = task_template['source'].format(user_id, ' -> '.join(purchase_history),
                                                                 ' , '.join(candidate_samples))
                target_text = task_template['target'].format(target_item)
            elif task_template['id'] == '2-8' or task_template['id'] == '2-10':
                if self.mode in ['train', 'val']:
                    user_seq = self.user_items[user_id]
                    candidate_samples = []
                    candidate_num = random.randint(79, 99)
                    while len(candidate_samples) < candidate_num:
                        if self.sample_type == 'random':
                            sample_ids = np.random.choice(self.all_item, candidate_num, replace=False)
                        else:
                            sample_ids = np.random.choice(self.all_item, candidate_num, replace=False,
                                                          p=self.probability)
                        sample_ids = [str(item) for item in sample_ids if
                                      item not in user_seq and item not in candidate_samples]
                        candidate_samples.extend(sample_ids)
                    candidate_samples = candidate_samples[:candidate_num]
                elif self.mode == 'test':
                    assert user_id == self.negative_samples[int(user_id) - 1].split(' ', 1)[0]
                    candidate_samples = self.negative_samples[int(user_id) - 1].split(' ', 1)[1].split(' ')
                else:
                    raise NotImplementedError
                candidate_samples.extend([target_item])
                random.shuffle(candidate_samples)
                rand_prob = random.random()
                if rand_prob > 0.5:
                    source_text = task_template['source'].format(user_desc, ' , '.join(purchase_history),
                                                                 ' , '.join(candidate_samples))
                else:
                    source_text = task_template['source'].format(user_desc, ' -> '.join(purchase_history),
                                                                 ' , '.join(candidate_samples))
                target_text = task_template['target'].format(target_item)
            elif task_template['id'] == '2-11':
                symbol_prob = random.random()
                if symbol_prob > 0.5:
                    symbol = ' , '
                else:
                    symbol = ' -> '
                rand_prob = random.random()
                if rand_prob > 0.5:
                    source_text = task_template['source'].format(user_id, symbol.join(purchase_history), target_item)
                    target_text = task_template['target'].format('yes')
                else:
                    user_seq = self.user_items[user_id]
                    candidate_samples = []
                    candidate_num = 1
                    while len(candidate_samples) < candidate_num:
                        if self.sample_type == 'random':
                            sample_ids = np.random.choice(self.all_item, candidate_num, replace=False)
                        else:
                            sample_ids = np.random.choice(self.all_item, candidate_num, replace=False,
                                                          p=self.probability)
                        sample_ids = [str(item) for item in sample_ids if
                                      item not in user_seq and item not in candidate_samples]
                        candidate_samples.extend(sample_ids)
                    candidate_samples = candidate_samples[:candidate_num]
                    source_text = task_template['source'].format(user_id, symbol.join(purchase_history),
                                                                 candidate_samples[0])
                    target_text = task_template['target'].format('no')
            elif task_template['id'] == '2-12':
                symbol_prob = random.random()
                if symbol_prob > 0.5:
                    symbol = ' , '
                else:
                    symbol = ' -> '
                rand_prob = random.random()
                if rand_prob > 0.5:
                    source_text = task_template['source'].format(user_desc, symbol.join(purchase_history), target_item)
                    target_text = task_template['target'].format('yes')
                else:
                    user_seq = self.user_items[user_id]
                    candidate_samples = []
                    candidate_num = 1
                    while len(candidate_samples) < candidate_num:
                        if self.sample_type == 'random':
                            sample_ids = np.random.choice(self.all_item, candidate_num, replace=False)
                        else:
                            sample_ids = np.random.choice(self.all_item, candidate_num, replace=False,
                                                          p=self.probability)
                        sample_ids = [str(item) for item in sample_ids if
                                      item not in user_seq and item not in candidate_samples]
                        candidate_samples.extend(sample_ids)
                    candidate_samples = candidate_samples[:candidate_num]
                    source_text = task_template['source'].format(user_desc, symbol.join(purchase_history),
                                                                 candidate_samples[0])
                    target_text = task_template['target'].format('no')
            elif task_template['id'] == '2-13':
                rand_prob = random.random()
                if rand_prob > 0.5:
                    source_text = task_template['source'].format(user_desc, ' , '.join(purchase_history))
                else:
                    source_text = task_template['source'].format(user_desc, ' -> '.join(purchase_history))
                target_text = task_template['target'].format(target_item)
            else:
                raise NotImplementedError

        elif task_name == 'explanation':
            exp_datum = self.exp_data[datum_idx]
            task_candidates = self.task_list[task_name]
            task_idx = random.randint(0, len(task_candidates) - 1)  # random choose the task index for task_candidates
            task_template = self.all_tasks['explanation'][task_candidates[task_idx]]
            assert task_template['task'] == 'explanation'

            if task_template['id'] == '3-1':
                if 'name' in self.meta_data[self.meta_dict[exp_datum['asin']]]:
                    title = self.meta_data[self.meta_dict[exp_datum['asin']]]['name']
                else:
                    title = 'unknown name'
                source_text = task_template['source'].format(self.user2id[exp_datum['reviewerID']], title)
                target_text = task_template['target'].format(exp_datum['explanation'])
            elif task_template['id'] == '3-2':
                if 'name' in self.meta_data[self.meta_dict[exp_datum['asin']]]:
                    title = self.meta_data[self.meta_dict[exp_datum['asin']]]['name']
                else:
                    title = 'unknown name'
                source_text = task_template['source'].format(self.user2id[exp_datum['reviewerID']],
                                                             int(exp_datum['overall']), title)
                target_text = task_template['target'].format(exp_datum['explanation'])
            elif task_template['id'] == '3-3':
                if 'name' in self.user_data[self.user_meta_dict[exp_datum['reviewerID']]]:
                    user_desc = self.user_data[self.user_meta_dict[exp_datum['reviewerID']]]['name']
                else:
                    user_desc = exp_datum['reviewerID']
                if 'name' in self.meta_data[self.meta_dict[exp_datum['asin']]]:
                    title = self.meta_data[self.meta_dict[exp_datum['asin']]]['name']
                else:
                    title = 'unknown name'
                source_text = task_template['source'].format(user_desc, title)
                target_text = task_template['target'].format(exp_datum['explanation'])
            elif task_template['id'] == '3-4':
                if 'name' in self.user_data[self.user_meta_dict[exp_datum['reviewerID']]]:
                    user_desc = self.user_data[self.user_meta_dict[exp_datum['reviewerID']]]['name']
                else:
                    user_desc = exp_datum['reviewerID']
                source_text = task_template['source'].format(user_desc, int(exp_datum['overall']),
                                                             self.item2id[exp_datum['asin']])
                target_text = task_template['target'].format(exp_datum['explanation'])
            elif task_template['id'] == '3-5':
                source_text = task_template['source'].format(exp_datum['feature'],
                                                             self.user2id[exp_datum['reviewerID']],
                                                             self.item2id[exp_datum['asin']])
                target_text = task_template['target'].format(self.gaussian_sampling(exp_datum),
                                                             exp_datum['explanation'])
            elif task_template['id'] == '3-6':
                if 'name' in self.user_data[self.user_meta_dict[exp_datum['reviewerID']]]:
                    user_desc = self.user_data[self.user_meta_dict[exp_datum['reviewerID']]]['name']
                else:
                    user_desc = exp_datum['reviewerID']
                source_text = task_template['source'].format(user_desc, self.item2id[exp_datum['asin']])
                target_text = task_template['target'].format(self.gaussian_sampling(exp_datum),
                                                             exp_datum['explanation'])
            elif task_template['id'] == '3-7':
                if 'name' in self.meta_data[self.meta_dict[exp_datum['asin']]]:
                    title = self.meta_data[self.meta_dict[exp_datum['asin']]]['name']
                else:
                    title = 'unknown name'
                source_text = task_template['source'].format(exp_datum['feature'],
                                                             self.user2id[exp_datum['reviewerID']], title)
                target_text = task_template['target'].format(exp_datum['explanation'])
            elif task_template['id'] == '3-8':
                if 'name' in self.user_data[self.user_meta_dict[exp_datum['reviewerID']]]:
                    user_desc = self.user_data[self.user_meta_dict[exp_datum['reviewerID']]]['name']
                else:
                    user_desc = exp_datum['reviewerID']
                if 'name' in self.meta_data[self.meta_dict[exp_datum['asin']]]:
                    title = self.meta_data[self.meta_dict[exp_datum['asin']]]['name']
                else:
                    title = 'unknown name'
                source_text = task_template['source'].format(exp_datum['feature'], user_desc, title)
                target_text = task_template['target'].format(exp_datum['explanation'])
            elif task_template['id'] == '3-9':
                source_text = task_template['source'].format(exp_datum['feature'], int(exp_datum['overall']),
                                                             self.user2id[exp_datum['reviewerID']],
                                                             self.item2id[exp_datum['asin']])
                target_text = task_template['target'].format(exp_datum['explanation'])
            elif task_template['id'] == '3-10':
                if 'name' in self.user_data[self.user_meta_dict[exp_datum['reviewerID']]]:
                    user_desc = self.user_data[self.user_meta_dict[exp_datum['reviewerID']]]['name']
                else:
                    user_desc = exp_datum['reviewerID']
                source_text = task_template['source'].format(exp_datum['feature'], int(exp_datum['overall']), user_desc,
                                                             self.item2id[exp_datum['asin']])
                target_text = task_template['target'].format(exp_datum['explanation'])
            else:
                raise NotImplementedError

        elif task_name == 'review':
            review_datum = self.review_data[datum_idx]
            task_candidates = self.task_list[task_name]
            task_idx = random.randint(0, len(task_candidates) - 1)  # random choose the task index for task_candidates
            task_template = self.all_tasks['review'][task_candidates[task_idx]]
            assert task_template['task'] == 'review'

            if task_template['id'] == '4-1':
                source_text = task_template['source'].format(self.user2id[review_datum['reviewerID']],
                                                             review_datum['reviewText'])
                target_text = task_template['target'].format(int(review_datum['overall']))
            elif task_template['id'] == '4-2':
                source_text = task_template['source'].format(self.user2id[review_datum['reviewerID']],
                                                             review_datum['reviewText'])
                target_text = task_template['target'].format(int(review_datum['overall']))
            elif task_template['id'] == '4-3':
                if 'name' in self.user_data[self.user_meta_dict[review_datum['reviewerID']]]:
                    user_desc = self.user_data[self.user_meta_dict[review_datum['reviewerID']]]['name']
                else:
                    user_desc = review_datum['reviewerID']
                source_text = task_template['source'].format(user_desc, review_datum['reviewText'])
                target_text = task_template['target'].format(int(review_datum['overall']))
            else:
                raise NotImplementedError

        elif task_name == 'traditional':
            sequential_datum = self.sequential_data[datum_idx]
            sequence = sequential_datum.split()
            user_id = sequence[0]
            user_desc = self.user_id2name[user_id]
            if self.mode == 'train':
                target_candidates = sequence[1:-2]
                target_idx = random.randint(0,
                                            len(target_candidates) - 1)  # random choose the target index for target_candidates
                target_item = target_candidates[target_idx]
            elif self.mode == 'val':
                target_item = sequence[-2]
            elif self.mode == 'test':
                target_item = sequence[-1]
            else:
                raise NotImplementedError

            task_candidates = self.task_list[task_name]
            task_idx = random.randint(0, len(task_candidates) - 1)  # random choose the task index for task_candidates
            task_template = self.all_tasks['traditional'][task_candidates[task_idx]]
            assert task_template['task'] == 'traditional'

            if task_template['id'] == '5-1':
                rand_prob = random.random()
                if rand_prob > 0.5:
                    source_text = task_template['source'].format(user_id, target_item)
                    target_text = task_template['target'].format('yes')
                else:
                    user_seq = self.user_items[user_id]
                    candidate_samples = []
                    candidate_num = 1
                    while len(candidate_samples) < candidate_num:
                        if self.sample_type == 'random':
                            sample_ids = np.random.choice(self.all_item, candidate_num, replace=False)
                        else:
                            sample_ids = np.random.choice(self.all_item, candidate_num, replace=False,
                                                          p=self.probability)
                        sample_ids = [str(item) for item in sample_ids if
                                      item not in user_seq and item not in candidate_samples]
                        candidate_samples.extend(sample_ids)
                    candidate_samples = candidate_samples[:candidate_num]
                    source_text = task_template['source'].format(user_id, candidate_samples[0])
                    target_text = task_template['target'].format('no')
            elif task_template['id'] == '5-2':
                rand_prob = random.random()
                if rand_prob > 0.5:
                    source_text = task_template['source'].format(target_item, user_desc)
                    target_text = task_template['target'].format('yes')
                else:
                    user_seq = self.user_items[user_id]
                    candidate_samples = []
                    candidate_num = 1
                    while len(candidate_samples) < candidate_num:
                        if self.sample_type == 'random':
                            sample_ids = np.random.choice(self.all_item, candidate_num, replace=False)
                        else:
                            sample_ids = np.random.choice(self.all_item, candidate_num, replace=False,
                                                          p=self.probability)
                        sample_ids = [str(item) for item in sample_ids if
                                      item not in user_seq and item not in candidate_samples]
                        candidate_samples.extend(sample_ids)
                    candidate_samples = candidate_samples[:candidate_num]
                    source_text = task_template['source'].format(candidate_samples[0], user_desc)
                    target_text = task_template['target'].format('no')
            elif task_template['id'] == '5-3':
                rand_prob = random.random()
                if rand_prob > 0.5:
                    if 'name' in self.meta_data[self.meta_dict[self.id2item[target_item]]]:
                        title = self.meta_data[self.meta_dict[self.id2item[target_item]]]['name']
                    else:
                        title = 'unknown name'
                    source_text = task_template['source'].format(user_desc, title)
                    target_text = task_template['target'].format('yes')
                else:
                    user_seq = self.user_items[user_id]
                    candidate_samples = []
                    candidate_num = 1
                    while len(candidate_samples) < candidate_num:
                        if self.sample_type == 'random':
                            sample_ids = np.random.choice(self.all_item, candidate_num, replace=False)
                        else:
                            sample_ids = np.random.choice(self.all_item, candidate_num, replace=False,
                                                          p=self.probability)
                        sample_ids = [str(item) for item in sample_ids if
                                      item not in user_seq and item not in candidate_samples]
                        candidate_samples.extend(sample_ids)
                    candidate_samples = candidate_samples[:candidate_num]
                    if 'name' in self.meta_data[self.meta_dict[self.id2item[candidate_samples[0]]]]:
                        title = self.meta_data[self.meta_dict[self.id2item[candidate_samples[0]]]]['name']
                    else:
                        title = 'unknown name'
                    source_text = task_template['source'].format(user_desc, title)
                    target_text = task_template['target'].format('no')
            elif task_template['id'] == '5-4':
                rand_prob = random.random()
                if rand_prob > 0.5:
                    if 'name' in self.meta_data[self.meta_dict[self.id2item[target_item]]]:
                        title = self.meta_data[self.meta_dict[self.id2item[target_item]]]['name']
                    else:
                        title = 'unknown name'
                    source_text = task_template['source'].format(user_id, title)
                    target_text = task_template['target'].format('yes')
                else:
                    user_seq = self.user_items[user_id]
                    candidate_samples = []
                    candidate_num = 1
                    while len(candidate_samples) < candidate_num:
                        if self.sample_type == 'random':
                            sample_ids = np.random.choice(self.all_item, candidate_num, replace=False)
                        else:
                            sample_ids = np.random.choice(self.all_item, candidate_num, replace=False,
                                                          p=self.probability)
                        sample_ids = [str(item) for item in sample_ids if
                                      item not in user_seq and item not in candidate_samples]
                        candidate_samples.extend(sample_ids)
                    candidate_samples = candidate_samples[:candidate_num]
                    if 'name' in self.meta_data[self.meta_dict[self.id2item[candidate_samples[0]]]]:
                        title = self.meta_data[self.meta_dict[self.id2item[candidate_samples[0]]]]['name']
                    else:
                        title = 'unknown name'
                    source_text = task_template['source'].format(user_id, title)
                    target_text = task_template['target'].format('no')
            elif task_template['id'] == '5-5' or task_template['id'] == '5-6':
                user_seq = self.user_items[user_id]
                candidate_samples = []
                candidate_num = 99  # random.randint(19, 99)
                while len(candidate_samples) < candidate_num:
                    if self.sample_type == 'random':
                        sample_ids = np.random.choice(self.all_item, candidate_num, replace=False)
                    else:
                        sample_ids = np.random.choice(self.all_item, candidate_num, replace=False, p=self.probability)
                    sample_ids = [str(item) for item in sample_ids if
                                  item not in user_seq and item not in candidate_samples]
                    candidate_samples.extend(sample_ids)
                candidate_samples = candidate_samples[:candidate_num]
                candidate_samples.extend([target_item])
                random.shuffle(candidate_samples)
                source_text = task_template['source'].format(user_desc, ' , '.join(candidate_samples))
                target_text = task_template['target'].format(target_item)
            elif task_template['id'] == '5-7' or task_template['id'] == '5-8':
                user_seq = self.user_items[user_id]
                candidate_samples = []
                candidate_num = 99  # random.randint(19, 99)
                while len(candidate_samples) < candidate_num:
                    if self.sample_type == 'random':
                        sample_ids = np.random.choice(self.all_item, candidate_num, replace=False)
                    else:
                        sample_ids = np.random.choice(self.all_item, candidate_num, replace=False, p=self.probability)
                    sample_ids = [str(item) for item in sample_ids if
                                  item not in user_seq and item not in candidate_samples]
                    candidate_samples.extend(sample_ids)
                candidate_samples = candidate_samples[:candidate_num]
                candidate_samples.extend([target_item])
                random.shuffle(candidate_samples)
                source_text = task_template['source'].format(user_id, ' , '.join(candidate_samples))
                target_text = task_template['target'].format(target_item)
            else:
                raise NotImplementedError

        else:
            raise NotImplementedError

        input_ids = self.tokenizer.encode(source_text, padding=True, truncation=True, max_length=self.args.max_text_length)
        tokenized_text = self.tokenizer.tokenize(source_text)
        whole_word_ids = self.calculate_whole_word_ids(tokenized_text, input_ids)
        assert len(whole_word_ids) == len(input_ids)

        target_ids = self.tokenizer.encode(target_text, padding=True, truncation=True, max_length=self.args.gen_max_length)

        out_dict['idx'] = idx
        out_dict['input_ids'] = torch.LongTensor(input_ids)
        out_dict['input_length'] = len(input_ids)
        out_dict['whole_word_ids'] = torch.LongTensor(whole_word_ids)
        out_dict['target_ids'] = torch.LongTensor(target_ids)
        out_dict['target_length'] = len(target_ids)

        out_dict['source_text'] = source_text
        out_dict['tokenized_text'] = tokenized_text
        out_dict['target_text'] = target_text

        out_dict['task'] = task_template['task']

        out_dict['loss_weight'] = loss_weight

        return out_dict



def get_data(args, task_list, sample_numbers, mode='train', *, num_workers=16, shuffle=False):
    """Build a DataLoader over a ``P5YelpDataset`` for one data split.

    Args:
        args: Hyper-parameter container (e.g. ``ModelParams``). ``model_name``,
            ``max_text_length``, ``do_lower_case`` and ``batch_size`` are read here, and
            the object is also forwarded to the dataset.
        task_list: Task selection forwarded to the dataset (which indexes it by task name).
        sample_numbers: Forwarded to the dataset unchanged.
        mode: Split name passed to the dataset (``'train'``, ``'val'`` or ``'test'``).
        num_workers: Number of DataLoader worker processes (previously hard-coded to 16).
        shuffle: Whether the DataLoader shuffles the dataset. Defaults to ``False`` for every
            split to preserve the original behaviour.
            NOTE(review): training loaders normally shuffle; consider ``shuffle=True`` for
            ``mode='train'``.

    Returns:
        A ``torch.utils.data.DataLoader`` whose batches are built by ``dataset.collate_fn``.
    """
    tokenizer = CustomTokenizer.from_pretrained(
        args.model_name,
        max_length=args.max_text_length,
        do_lower_case=args.do_lower_case,
    )

    dataset = P5YelpDataset(
        task_templates,
        task_list,
        tokenizer,
        args,
        sample_numbers,
        mode=mode,
        rating_augment=False,
    )

    # Train and eval loaders used identical settings (drop_last=False is the DataLoader
    # default), so a single construction covers every split.
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=dataset.collate_fn,
        drop_last=False,
    )

#### Build Model

In [None]:
@dataclass
class P5Seq2SeqLMOutput(ModelOutput):
    """Output container returned by the P5 seq2seq model's ``forward``.

    As currently written, ``P5.forward`` fills ``loss``, ``logits``, ``past_key_values``,
    ``decoder_last_hidden_state`` and ``decoder_hidden_states``; the other fields stay ``None``.
    """

    # Cross-entropy loss: a scalar when reduced, otherwise one value per target token.
    loss: Optional[torch.FloatTensor] = None
    # Language-model logits; the last dimension is the vocabulary size.
    logits: Optional[torch.FloatTensor] = None
    # Decoder cache as returned by the T5 decoder stack. Presumably a legacy tuple of tuples
    # or a transformers Cache object depending on the installed version; confirm before relying on it.
    past_key_values: Optional[Any] = None
    # Final decoder hidden states: a single tensor, not a tuple.
    decoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # Per-layer decoder hidden states when requested.
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Per-layer decoder attention weights when requested.
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Final encoder hidden states.
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # Per-layer encoder hidden states when requested.
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Per-layer encoder attention weights when requested.
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None



class JointEncoder(T5Stack):
    """T5 encoder stack that adds a learned whole-word embedding to each token embedding.

    ``whole_word_ids`` assigns each token a whole-word index (presumably shared by all
    sub-tokens of one word; see the dataset's ``calculate_whole_word_ids``). The embedding
    for that index is added to the token embedding before the first block. The
    relative-position bias is computed once from block 0 and reused by every block.
    """

    def __init__(self, config, embed_tokens=None):
        # Deliberately skip T5Stack.__init__ (which would build its own layers) and run the
        # initializer of T5Stack's parent; the layers are constructed explicitly below.
        super(T5Stack, self).__init__(config)
        self.config = config
        self.embed_tokens = embed_tokens
        self.is_decoder = self.config.is_decoder
        assert self.config.is_decoder is False

        self.block = nn.ModuleList(
            [T5Block(config, has_relative_attention_bias=(i == 0)) for i in range(config.num_layers)]
        )
        self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        # At most 512 distinct whole-word indices per source text (config.d_model is 768 for t5-base).
        self.whole_word_embeddings = nn.Embedding(512, config.d_model)
        self.init_weights()

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module used by this encoder."""
        self.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
        whole_word_ids=None,
    ):
        """Encode a batch of source sequences.

        Args:
            input_ids: Token ids, embedded with ``embed_tokens`` when ``inputs_embeds`` is None.
            attention_mask: 1 for tokens to attend to, 0 for padding. Defaults to
                ``input_ids != config.pad_token_id``, or all ones when only ``inputs_embeds``
                is given.
            inputs_embeds: Precomputed input embeddings. When given, ``input_ids`` and
                ``whole_word_ids`` are not used for embedding.
            whole_word_ids: Whole-word index per token (must be < 512), added via
                ``whole_word_embeddings``.
            head_mask, past_key_values, use_cache, output_attentions, output_hidden_states,
            return_dict: As in ``T5Stack.forward``.
            encoder_hidden_states, encoder_attention_mask, cross_attn_head_mask,
            cache_position: Accepted for signature compatibility with ``T5Stack``; ignored.

        Returns:
            ``BaseModelOutputWithPastAndCrossAttentions`` when ``return_dict`` is truthy,
            otherwise a tuple of its non-``None`` fields in the same order.
        """
        if inputs_embeds is None:
            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)
            if whole_word_ids is not None:
                whole_word_embeds = self.whole_word_embeddings(whole_word_ids)
                assert whole_word_embeds.shape[-1] == inputs_embeds.shape[-1]
                inputs_embeds = inputs_embeds + whole_word_embeds

        B, L = inputs_embeds.size()[:-1]
        # Build masks on the device the activations actually live on.
        device = inputs_embeds.device

        if attention_mask is None:
            if input_ids is not None:
                attention_mask = input_ids.ne(self.config.pad_token_id).to(dtype=inputs_embeds.dtype, device=device)
            else:
                # No token ids to detect padding from: attend to every position.
                attention_mask = inputs_embeds.new_ones((B, L))

        # Broadcastable additive mask (0 for kept positions, large negative for padding).
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, (B, L), device)

        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None

        hidden_states = self.dropout(inputs_embeds)

        if self.config.num_layers > 0:
            assert self.block[0].layer[0].SelfAttention.has_relative_attention_bias

            # Relative-position bias from block 0, shared by all blocks. The padding mask is
            # folded in here because, in HF's T5Attention, the mask is only added when the
            # attention computes position_bias itself (verify for your transformers version).
            position_bias = self.block[0].layer[0].SelfAttention.compute_bias(L, L)
            position_bias = position_bias + extended_attention_mask

            for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
                if output_hidden_states:
                    all_hidden_states = all_hidden_states + (hidden_states,)

                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    position_bias=position_bias,
                    encoder_hidden_states=None,
                    encoder_attention_mask=None,
                    encoder_decoder_position_bias=None,
                    layer_head_mask=head_mask[i],
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

                # Assumed HF T5Block output layout for a non-decoder block:
                # (hidden_states, [present_key_value if use_cache], position_bias,
                #  [self-attention weights if output_attentions]).
                # position_bias is constant across layers, so it is not read back. Reading it
                # back by tuple length used to mistake the attention weights for the bias
                # whenever output_attentions=True.
                hidden_states = layer_outputs[0]
                if use_cache:
                    present_key_value_states = present_key_value_states + (layer_outputs[1],)
                if output_attentions:
                    all_attentions = all_attentions + (layer_outputs[-1],)

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )



class P5(T5ForConditionalGeneration):
    """T5 encoder-decoder whose encoder is a ``JointEncoder`` (token + whole-word embeddings).

    The encoder and the decoder share the ``shared`` token embedding. ``extend_vocab``
    explicitly ties ``lm_head`` to it; before that, tying presumably happens via HF's
    ``init_weights``/``tie_weights`` (confirm for your transformers version).
    """

    def __init__(self, config):
        # Skip T5ForConditionalGeneration.__init__ so the encoder can be a JointEncoder;
        # the initializer of its parent class runs instead.
        super(T5ForConditionalGeneration, self).__init__(config)
        self.config = config
        self.model_dim = config.d_model
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = JointEncoder(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.init_weights()

    def set_input_embeddings(self, new_embeddings):
        """Use ``new_embeddings`` as the shared token embedding for encoder and decoder."""
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def extend_vocab(self, vocab_size):
        """Resize the shared token embedding to ``vocab_size`` rows and re-tie every user.

        The first ``min(old, vocab_size)`` rows are copied from the current embedding. Any
        new rows keep ``nn.Embedding``'s default initialisation. The new modules are created
        on the current embedding's device and dtype, so this also works after
        ``model.to(device)``. ``lm_head`` is tied to the new shared weight and all configs
        are updated.
        """
        old_weight = self.shared.weight.detach()
        device, dtype = old_weight.device, old_weight.dtype

        new_shared = nn.Embedding(vocab_size, self.config.d_model, device=device, dtype=dtype)
        n_copy = min(old_weight.size(0), vocab_size)
        new_shared.weight.data[:n_copy, :] = old_weight[:n_copy]
        self.shared = new_shared

        # The LM head is tied to the shared embedding just below, so there is no point in
        # copying the old lm_head weights into it first.
        self.lm_head = nn.Linear(self.config.d_model, vocab_size, bias=False, device=device, dtype=dtype)

        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared
        self.lm_head.weight = self.shared.weight
        self.config.vocab_size = vocab_size
        self.encoder.config.vocab_size = vocab_size
        self.decoder.config.vocab_size = vocab_size

    def forward(
        self,
        input_ids=None,
        whole_word_ids=None,
        attention_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        labels=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        reduce_loss=False,
        return_hidden_state=False,
        **kwargs,
    ):
        """Run the encoder (unless ``encoder_outputs`` is given) and the decoder.

        Args:
            input_ids: Source token ids. May be ``None`` when ``encoder_outputs`` is supplied
                (as during generation, where ``prepare_inputs_for_generation`` passes no
                ``input_ids``) or when ``inputs_embeds`` is given.
            whole_word_ids: Per-token whole-word indices consumed by ``JointEncoder``.
            labels: Target token ids. Positions equal to -100 are ignored by the loss. When
                given without decoder inputs, the labels are shifted right to form them.
            reduce_loss: If True, ``loss`` is the mean token loss. Otherwise it is the
                unreduced per-token loss, flattened over batch and target length.
            return_hidden_state: If True, return the scaled decoder output tensor instead of
                a ``P5Seq2SeqLMOutput``.
            Other arguments follow ``T5ForConditionalGeneration.forward``.

        Returns:
            ``P5Seq2SeqLMOutput`` (regardless of ``return_dict``), or a tensor when
            ``return_hidden_state`` is True.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Callers may hand over CPU tensors; move every supplied input to the model device.
        if input_ids is not None:
            input_ids = input_ids.to(DEVICE)
        if whole_word_ids is not None:
            whole_word_ids = whole_word_ids.to(DEVICE)
        if attention_mask is not None:
            attention_mask = attention_mask.to(DEVICE)
        if labels is not None:
            labels = labels.to(DEVICE)

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                whole_word_ids=whole_word_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        if attention_mask is None:
            if input_ids is not None:
                attention_mask = input_ids.ne(self.config.pad_token_id).to(
                    dtype=hidden_states.dtype, device=hidden_states.device)
            else:
                # No source ids to detect padding from: attend to every encoder position.
                attention_mask = hidden_states.new_ones(hidden_states.shape[:2])
        encoder_attention_mask = attention_mask

        # Always ask the decoder for a ModelOutput: attribute access
        # (decoder_outputs.past_key_values etc.) below would fail on a plain tuple.
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,

            encoder_hidden_states=hidden_states,
            encoder_attention_mask=encoder_attention_mask,

            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        sequence_output = decoder_outputs[0]

        assert self.config.tie_word_embeddings is True

        if self.config.tie_word_embeddings:
            # Rescale before the (tied) projection to the vocabulary, as T5 does.
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        if return_hidden_state:
            return sequence_output

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            if reduce_loss:
                loss_fct = CrossEntropyLoss(ignore_index=-100)
            else:
                loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
            loss = loss_fct(
                lm_logits.view(-1, lm_logits.size(-1)),
                labels.view(-1))

        return P5Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past=None, attention_mask=None, use_cache=None,
        encoder_outputs=None,
        **kwargs):
        """Build the ``forward`` kwargs for one generation step.

        ``input_ids`` here are the decoder ids generated so far. If ``past`` is given, only
        the last token is fed.
        NOTE(review): newer transformers versions pass the cache as ``past_key_values``,
        which lands in ``kwargs`` and is ignored here. Each step then re-decodes the full
        prefix, which is correct but slower. Confirm before wiring the cache through.
        """
        if past is not None:
            input_ids = input_ids[:, -1:]

        output = {
            "decoder_input_ids": input_ids,
            "past_key_values": past,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }

        return output

    @staticmethod
    def _expand_inputs_for_generation(
        input_ids: torch.LongTensor,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        attention_mask: torch.LongTensor = None,
        encoder_outputs: ModelOutput = None,
        **model_kwargs
    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
        """Repeat every batch row ``expand_size`` times (e.g. for beam search).

        ``input_ids``, ``token_type_ids`` and ``attention_mask`` are expanded along dim 0.
        For encoder-decoder models, ``encoder_outputs.last_hidden_state`` is expanded in
        place and returned in ``model_kwargs``.
        """
        # Row indices [0]*k + [1]*k + ...: each original row repeated expand_size times in order.
        expanded_return_idx = torch.arange(
            input_ids.shape[0], device=input_ids.device).repeat_interleave(expand_size)
        input_ids = input_ids.index_select(0, expanded_return_idx)

        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = token_type_ids.index_select(
                0, expanded_return_idx)

        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask.index_select(
                0, expanded_return_idx)

        if is_encoder_decoder:
            assert encoder_outputs is not None
            encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(0, expanded_return_idx)
            model_kwargs["encoder_outputs"] = encoder_outputs

        return input_ids, model_kwargs

#### Trainer Base

In [None]:
class TrainerBase(object):
    """Shared plumbing for P5 trainers: config, tokenizer, model, optimizer, checkpoints.

    ``args`` is a plain attribute container such as ``ModelParams``. The base
    constructor does not build a model; subclasses must set ``self.model``
    before calling ``create_optimizer_and_scheduler``, ``init_weights``,
    ``load_checkpoint``, ``save`` or ``load``.
    """

    # Substrings of parameter names that are excluded from weight decay.
    # Hugging Face's T5 modules have no biases and call their norms
    # ``layer_norm`` / ``final_layer_norm``, so the BERT-style
    # "LayerNorm.weight" entry alone never matched and every norm weight was
    # being decayed.
    NO_DECAY = ("bias", "LayerNorm.weight", "layer_norm.weight")

    def __init__(self, args, train_loader=None, val_loader=None, test_loader=None, train=True):
        # `train` is accepted for subclasses; the base class does not use it.
        self.args = args
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.verbose = True
        # Fall back to the backbone name when no tokenizer setting is given.
        if self.args.tokenizer is None:
            self.args.tokenizer = self.args.model_name
        if not self.verbose:
            set_global_logging_level(logging.ERROR, ["transformers"])

    def create_config(self):
        """Load the backbone ``T5Config`` and apply dropout and loss settings from ``args``."""
        config = T5Config.from_pretrained(self.args.model_name)
        args = self.args
        config.dropout_rate = args.dropout
        config.dropout = args.dropout
        config.attention_dropout = args.dropout
        config.activation_dropout = args.dropout
        # Read back by P5Pretraining as `self.config.losses`.
        config.losses = args.losses
        return config

    def create_model(self, model_class, config=None, **kwargs):
        """Instantiate ``model_class`` from the pretrained ``args.model_name`` checkpoint."""
        return model_class.from_pretrained(
            self.args.model_name,
            config=config,
            **kwargs
        )

    def create_tokenizer(self, **kwargs):
        """Load ``CustomTokenizer`` from ``args.model_name`` (not from ``args.tokenizer``)."""
        return CustomTokenizer.from_pretrained(
            self.args.model_name,
            max_length=self.args.max_text_length,
            do_lower_case=self.args.do_lower_case,
            **kwargs
        )

    def _grouped_parameters(self):
        """Split ``self.model`` parameters into a decayed group and a non-decayed group.

        Returns:
            Two optimizer param-group dicts. The first uses ``args.weight_decay``;
            the second holds parameters whose name contains any ``NO_DECAY``
            substring and uses ``0.0``.
        """
        decay, no_decay = [], []
        for name, param in self.model.named_parameters():
            if any(marker in name for marker in self.NO_DECAY):
                no_decay.append(param)
            else:
                decay.append(param)
        return [
            {"params": decay, "weight_decay": self.args.weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ]

    def create_optimizer_and_scheduler(self):
        """Build AdamW plus a linear warmup/decay learning-rate schedule.

        The schedule spans ``len(train_loader) // gradient_accumulation_steps
        * epoch`` optimizer steps, warming up over the first ``warmup_ratio``
        fraction of them.

        Returns:
            ``(optimizer, lr_scheduler)``.
        """
        if self.verbose:
            print('Building Optimizer')

        batch_per_epoch = len(self.train_loader)
        t_total = batch_per_epoch // self.args.gradient_accumulation_steps * self.args.epoch
        warmup_ratio = self.args.warmup_ratio
        warmup_iters = int(t_total * warmup_ratio)
        if self.verbose:
            print("Batch per epoch: %d" % batch_per_epoch)
            print("Total Iters: %d" % t_total)
            print('Warmup ratio:', warmup_ratio)
            print("Warm up Iters: %d" % warmup_iters)

        # torch.optim.AdamW is the maintained replacement for the deprecated
        # transformers.AdamW. The betas now come from args instead of being
        # silently left at the library defaults.
        betas = (
            getattr(self.args, 'adam_beta1', 0.9),
            getattr(self.args, 'adam_beta2', 0.999),
        )
        optim = torch.optim.AdamW(
            self._grouped_parameters(),
            lr=self.args.lr,
            eps=self.args.adam_eps,
            betas=betas,
        )
        lr_scheduler = get_linear_schedule_with_warmup(optim, warmup_iters, t_total)
        return optim, lr_scheduler

    def load_checkpoint(self, ckpt_path):
        """Load a state dict from ``ckpt_path`` into ``self.model`` (non-strict)."""
        state_dict = load_state_dict(ckpt_path, 'cpu')
        results = self.model.load_state_dict(state_dict, strict=False)
        if self.verbose:
            print('Model loaded from ', ckpt_path)
            pprint(results)

    def init_weights(self):
        """Re-initialise ``self.model``'s weights, overwriting any loaded ones."""

        def init_bert_weights(module):
            """Initialize the weights."""
            # NOTE(review): std=1 is far wider than the usual BERT-style 0.02;
            # confirm this is intended for every Linear layer.
            if isinstance(module, (nn.Linear, nn.Embedding)):
                module.weight.data.normal_(mean=0.0, std=1)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        self.model.apply(init_bert_weights)
        # Hugging Face's own init pass; which modules it touches after
        # from_pretrained depends on the installed transformers version - verify.
        self.model.init_weights()

    def predict(self):
        pass

    def evaluate(self):
        pass

    def save(self, name):
        """Save ``self.model``'s state dict to ``<args.output>/<name>.pth``."""
        if not os.path.isdir(self.args.output):
            os.makedirs(self.args.output, exist_ok=True)
        torch.save(self.model.state_dict(), os.path.join(self.args.output, "%s.pth" % name))

    def load(self, path, loc=None):
        """Load ``<path>.pth`` into ``self.model`` (non-strict).

        When ``loc`` is omitted and ``args`` has a ``gpu`` attribute, tensors
        are mapped to ``cuda:<gpu>``.
        """
        if loc is None and hasattr(self.args, 'gpu'):
            loc = f'cuda:{self.args.gpu}'
        state_dict = torch.load("%s.pth" % path, map_location=loc)
        results = self.model.load_state_dict(state_dict, strict=False)
        if self.verbose:
            print('Model loaded from ', path)
            pprint(results)


class P5Pretraining(P5):
    """P5 model with per-batch train / validation / generation steps.

    ``config.losses`` lists the loss names to report (e.g. ``'rating_loss'``).
    An example whose ``batch['task']`` is ``'rating'`` is reported under the
    keys ``'rating_loss'`` and ``'rating_loss_count'``, i.e. the names from
    ``config.losses``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.losses = self.config.losses

    @staticmethod
    def _per_example_loss(token_loss, lm_labels):
        """Average token-level losses per example, ignoring ``-100`` labels.

        ``token_loss`` must hold one value per label position (it is viewed
        as ``lm_labels.size()``). Examples without any valid label get 0.
        """
        lm_mask = (lm_labels != -100).float()
        batch_size, length = lm_labels.size()
        loss = token_loss.view(batch_size, length) * lm_mask
        return loss.sum(dim=1) / lm_mask.sum(dim=1).clamp(min=1)

    def _summarize(self, loss, loss_weights, tasks):
        """Build the results dict shared by ``train_step`` and ``valid_step``.

        Returns:
            dict with:
              - ``loss``: weighted mean, still attached to the graph.
              - ``total_loss`` / ``total_loss_count``: detached sum and example count.
              - ``<name>`` / ``<name>_count``: for every name in ``self.losses``
                (except ``total_loss``) with at least one matching example.

            Examples whose task is not listed in ``self.losses`` still count
            toward the totals but get no per-task entry. Previously this
            raised KeyError.
        """
        results = {
            'loss': (loss * loss_weights).mean(),
            'total_loss': loss.detach().sum(),
            'total_loss_count': len(loss),
        }

        task_loss, task_counts = {}, {}
        for example_loss, task in zip(loss.detach(), tasks):
            task_loss[task] = task_loss.get(task, 0) + example_loss
            task_counts[task] = task_counts.get(task, 0) + 1

        for loss_name in self.losses:
            if loss_name == 'total_loss':
                continue  # reported above; never overwrite it
            task = loss_name.replace("_loss", "")
            if task_counts.get(task, 0) > 0:
                results[loss_name] = task_loss[task]
                results[f'{loss_name}_count'] = task_counts[task]
        return results

    def train_step(self, batch):
        """Forward one training batch and return the loss summary (see ``_summarize``)."""
        input_ids = batch['input_ids'].to(DEVICE)
        whole_word_ids = batch['whole_word_ids'].to(DEVICE)
        lm_labels = batch["target_ids"].to(DEVICE)
        loss_weights = batch["loss_weights"].to(DEVICE)
        output = self(
            input_ids=input_ids,
            whole_word_ids=whole_word_ids,
            labels=lm_labels,
            return_dict=True
        )
        assert 'loss' in output

        loss = self._per_example_loss(output['loss'], lm_labels)
        return self._summarize(loss, loss_weights, batch['task'])

    @torch.no_grad()
    def valid_step(self, batch):
        """Evaluate one batch without gradients; switches the model to eval mode.

        Returns the ``_summarize`` dict. When rating is among the reported
        losses, it also contains ``rating_pred`` with the decoded generations.
        """
        self.eval()
        input_ids = batch['input_ids'].to(DEVICE)
        lm_labels = batch["target_ids"].to(DEVICE)
        loss_weights = batch["loss_weights"].to(DEVICE)
        # Feed whole-word ids like train_step does, so validation scores the
        # same model input the model is trained on.
        whole_word_ids = batch['whole_word_ids'] if 'whole_word_ids' in batch else None
        if whole_word_ids is not None:
            whole_word_ids = whole_word_ids.to(DEVICE)
        output = self(
            input_ids=input_ids,
            whole_word_ids=whole_word_ids,
            labels=lm_labels,
            return_dict=True
        )
        assert 'loss' in output

        loss = self._per_example_loss(output['loss'], lm_labels)
        results = self._summarize(loss, loss_weights, batch['task'])

        # ModelParams uses '*_loss' names, so the old `'rating' in self.losses`
        # test never fired. Both spellings are accepted.
        if 'rating_loss' in self.losses or 'rating' in self.losses:
            output = self.generate(input_ids=input_ids)
            results['rating_pred'] = self.tokenizer.batch_decode(output, skip_special_tokens=True)

        return results

    @torch.no_grad()
    def generate_step(self, batch):
        """Generate for ``batch['input_ids']`` and return the decoded strings."""
        self.eval()
        input_ids = batch['input_ids'].to(DEVICE)
        output = self.generate(input_ids=input_ids)
        return self.tokenizer.batch_decode(output, skip_special_tokens=True)



class Trainer(TrainerBase):
    """Single-device trainer for ``P5Pretraining``.

    Builds the config, tokenizer and model from ``args`` and moves the model
    to ``DEVICE``. When ``train`` is true, it also creates the optimizer and
    LR scheduler. ``train()`` runs ``args.epoch`` epochs over ``train_loader``.
    """

    def __init__(self, args, train_loader=None, val_loader=None, test_loader=None, train=True):
        super().__init__(args, train_loader=train_loader, val_loader=val_loader, test_loader=test_loader, train=train)
        assert args.whole_word_embed

        model_kwargs = {}
        model_class = P5Pretraining

        config = self.create_config()
        self.tokenizer = self.create_tokenizer()
        self.model = self.create_model(model_class, config, **model_kwargs)
        # Size the embedding matrix to the custom tokenizer's vocabulary.
        if 'p5' in self.args.tokenizer:
            self.model.resize_token_embeddings(self.tokenizer.vocab_size)
        # P5Pretraining decodes generated ids with this tokenizer.
        self.model.tokenizer = self.tokenizer

        # Resuming from a checkpoint is not wired up; kept for callers that read it.
        self.start_epoch = None

        # `init_weights` overwrites every Linear/Embedding weight, including the
        # pretrained checkpoint `create_model` just loaded. It used to run
        # unconditionally. `args.from_scratch` now controls it and defaults to
        # True, so existing runs behave exactly as before. Set
        # `from_scratch = False` in ModelParams to fine-tune the pretrained weights.
        if getattr(self.args, 'from_scratch', True):
            self.init_weights()

        self.model = self.model.to(DEVICE)

        if train:
            self.optim, self.lr_scheduler = self.create_optimizer_and_scheduler()

    def train(self):
        """Train for ``args.epoch`` epochs, printing per-epoch loss averages.

        Two settings from ``args`` are honoured:
          - ``gradient_accumulation_steps``: the optimizer steps every N batches
            and on the last batch of an epoch. The LR schedule's total already
            counts optimizer steps this way.
          - ``clip_grad_norm``: maximum gradient norm; values <= 0 disable clipping.
        """
        loss_names = self.args.losses
        accum_steps = max(1, int(getattr(self.args, 'gradient_accumulation_steps', 1)))
        clip_norm = getattr(self.args, 'clip_grad_norm', -1.0)
        n_batches = len(self.train_loader)

        # Running per-loss values shown in the progress bar.
        loss_meters = [LossMeter() for _ in loss_names]

        for epoch in range(self.args.epoch):
            self.model.train()
            pbar = tqdm(total=n_batches, ncols=275)

            epoch_results = {}
            for loss_name in loss_names:
                epoch_results[loss_name] = 0.
                epoch_results[f'{loss_name}_count'] = 0

            for step_i, batch in enumerate(self.train_loader):
                results = self.model.train_step(batch)
                loss = results['loss']
                if accum_steps > 1:
                    loss = loss / accum_steps
                loss.backward()

                if (step_i + 1) % accum_steps == 0 or step_i + 1 == n_batches:
                    if clip_norm > 0:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_norm)
                    self.optim.step()
                    if self.lr_scheduler:
                        self.lr_scheduler.step()
                    # Dropping grads (instead of zeroing them) skips a memset.
                    for param in self.model.parameters():
                        param.grad = None

                # `get_last_lr` is the supported accessor; `get_lr` is meant to
                # be called only inside `step()`.
                if self.lr_scheduler:
                    lr = self.lr_scheduler.get_last_lr()[0]
                else:
                    lr = self.optim.param_groups[0]['lr']

                for k, v in results.items():
                    if k in epoch_results:
                        if isinstance(v, int):
                            epoch_results[k] += v
                        elif isinstance(v, torch.Tensor):
                            epoch_results[k] += v.item()

                # Refresh on every batch. The old `if step_i % 20000:` skipped
                # steps 0, 20000, ..., so the bar under-counted.
                desc_str = f'Epoch {epoch} | LR {lr:.6f} |'
                for loss_name, loss_meter in zip(loss_names, loss_meters):
                    if loss_name in results:
                        loss_meter.update(float(results[loss_name]) / results[f'{loss_name}_count'])
                    if len(loss_meter) > 0:
                        loss_count = epoch_results[f'{loss_name}_count']
                        desc_str += f' {loss_name} ({loss_count}) {loss_meter.val:.3f}'
                pbar.set_description(desc_str)
                pbar.update(1)

            pbar.close()
            # The epoch totals are already in `epoch_results`. The previous
            # `sum(epoch_results)` iterated the dict's string keys and raised
            # TypeError at the end of every epoch.
            print(self._format_epoch_losses(epoch_results))

    @staticmethod
    def _format_epoch_losses(epoch_results):
        """Render the epoch summary: the overall average, then each ``*loss`` entry with a non-zero count."""
        total_count = epoch_results.get('total_loss_count', 0)
        avg_train_loss = epoch_results.get('total_loss', 0.) / max(total_count, 1)
        losses_str = f"Train Loss: {avg_train_loss:.3f}\n"

        for name, loss in epoch_results.items():
            if name.endswith('loss'):
                loss_count = int(epoch_results[name + '_count'])
                if loss_count > 0:
                    losses_str += f"{name} ({loss_count}): {loss / loss_count:.3f} "
        losses_str += '\n'
        return losses_str

#### Train

In [None]:
def main_train(args):
    """Create train and validation loaders for all five prompt families, then train.

    Both splits use the same prompt templates per task. Only the number of
    samples drawn per task differs between them.
    """

    def task_templates():
        # Returns a fresh dict on each call, so the two loaders receive
        # independent objects, as with the former duplicated literals.
        return {
            'rating': ['1-1', '1-2', '1-3', '1-4', '1-5', '1-6', '1-7', '1-8', '1-9'],
            'sequential': ['2-1', '2-2', '2-3', '2-4', '2-5', '2-6', '2-7', '2-8', '2-9', '2-10', '2-11', '2-12'],
            'explanation': ['3-1', '3-2', '3-3', '3-4', '3-5', '3-6', '3-7', '3-8', '3-9'],
            'review': ['4-1', '4-2'],
            'traditional': ['5-1', '5-2', '5-3', '5-4', '5-5', '5-6', '5-7'],
        }

    train_sample_numbers = {'rating': 1, 'sequential': (5, 5, 10), 'explanation': 1, 'review': 1, 'traditional': (10, 5)}
    val_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}

    train_loader = get_data(args, task_templates(), train_sample_numbers, mode='train')
    val_loader = get_data(args, task_templates(), val_sample_numbers, mode='val')

    Trainer(args, train_loader, val_loader, train=True).train()



if __name__ == "__main__":
    args = ModelParams

    # Seed every RNG the pipeline uses before any data or model is built;
    # the three generators are independent, so the order does not matter.
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    main_train(args)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'CustomTokenizer'.
You are using the default legacy behaviour of the <class '__main__.CustomTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Total number of businesses: stored in self.meta_dict 20033
Total umber of users: stored in self.user_meta_dict 30431
compute_datum_info
Init call finished


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'CustomTokenizer'.


Total number of businesses: stored in self.meta_dict 20033
Total umber of users: stored in self.user_meta_dict 30431
compute_datum_info
Init call finished


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'CustomTokenizer'.


model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

Some weights of P5Pretraining were not initialized from the model checkpoint at t5-base and are newly initialized: ['encoder.whole_word_embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]



Building Optimizer
Batch per epoch: 441316
Total Iters: 441316
Warmup ratio: 0.05
Warm up Iters: 22065


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
Epoch 0 | LR 0.000244 | total_loss (21576) 2.654:   1%|██▏                                                                                                                                                                                | 5393/441316 [16:15<24:55:51,  4.86it/s]