# Libraries

In [None]:
import torch
import transformers
from transformers import BloomTokenizerFast, BloomForCausalLM, TrainingArguments

from datasets import load_dataset,list_datasets

from utils import ModifiedTrainer, tokenise_data, data_collator

In [None]:
# dataset = load_dataset('deepmind/code_contests',cache_dir='/cpfs01/user/Wuchen2023/datasets')
# dataset = load_dataset('codeparrot/github-code-clean',cache_dir='/cpfs01/user/Wuchen2023/datasets')

# Main

In [None]:
# Report whether a CUDA GPU is visible and pick the matching torch device.
cuda_available = torch.cuda.is_available()
print(cuda_available)
device = torch.device("cuda" if cuda_available else "cpu")

In [None]:
# Load the BLOOM checkpoint and its tokenizer from the local model directory.
model_name = "bloom-7b1"
model_path = f"/cpfs01/user/Wuchen2023/models/{model_name}"
# device_map="auto" (requires the `accelerate` package) dispatches the weights
# over the visible GPUs. Without it the model is created on the CPU. The
# training cell then flags it as model-parallel, and Trainer does not move
# model-parallel models to its device even though batches are still sent to
# the GPU, which ends in a device-mismatch error.
model = BloomForCausalLM.from_pretrained(model_path, device_map="auto")
tokenizer = BloomTokenizerFast.from_pretrained(model_path, add_prefix_space=True)

In [None]:
# Fetch the tatsu-lab/alpaca dataset; downloaded files are cached under the
# shared datasets directory so later runs reuse them.
dataset = load_dataset(
    path="tatsu-lab/alpaca",
    cache_dir="/cpfs01/user/Wuchen2023/datasets",
)

In [None]:
# Tokenise the loaded dataset with the BLOOM tokenizer; the result is used as
# the Trainer's train_dataset in the training cell.
# NOTE(review): tokenise_data comes from the local utils module and is handed
# the whole object returned by load_dataset (no split selected) — presumably it
# picks the "train" split itself; confirm.
input_ids = tokenise_data(dataset, tokenizer)

In [None]:
# Fine-tuning hyper-parameters: 2 epochs, 2 examples per device per step, no
# gradient accumulation, mixed precision disabled, a log line every 10 steps.
training_args = TrainingArguments(
    output_dir="output",
    num_train_epochs=2,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    learning_rate=2e-5,
    fp16=False,
    logging_steps=10,
)

# Recompute activations during the backward pass instead of storing them,
# trading extra compute for lower memory use.
model.gradient_checkpointing_enable()
# NOTE(review): Trainer reads these two flags to decide that the model is
# already placed across devices and must not be moved. They are only accurate
# if the weights were actually dispatched (e.g. loaded with device_map) — confirm.
model.is_parallelizable = True
model.model_parallel = True

trainer = ModifiedTrainer(
    model=model,
    args=training_args,
    train_dataset=input_ids,
    data_collator=data_collator,
)
trainer.train()

In [None]:
# Shared prompt and length budget for the decoding demos below.
prompt = "It was a dark and stormy night"
# Passed as max_length, so this budget includes the prompt's own tokens.
result_length = 50

# After trainer.train() the model may still be in training mode. With gradient
# checkpointing enabled, BLOOM then forces use_cache off, making generation
# recompute every step. eval() restores cached, inference-mode decoding.
model.eval()
# Place the prompt tensors on the same device as the model's (first) weights;
# CPU tensors fed to a GPU-resident model raise a device-mismatch error.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

In [None]:
# Greedy search: at every step keep only the single most likely next token.
greedy_ids = model.generate(inputs["input_ids"], max_length=result_length)
print(tokenizer.decode(greedy_ids[0]))

In [None]:
# Beam search over 2 beams. No 2-gram may appear twice, and decoding stops
# once enough finished candidates exist (early_stopping=True).
beam_ids = model.generate(
    inputs["input_ids"],
    max_length=result_length,
    num_beams=2,
    no_repeat_ngram_size=2,
    early_stopping=True,
)
print(tokenizer.decode(beam_ids[0]))

In [None]:
# Sampling Top-k + Top-p
print(tokenizer.decode(model.generate(inputs["input_ids"],
                       max_length=result_length, 
                       do_sample=True, 
                       top_k=50, 
                       top_p=0.9
                      )[0]))