In [None]:
%%writefile /content/colossal/callback.py
import psutil
import torch
import torch.distributed as dist
from pytorch_lightning.callbacks import Callback


def print_rank_0(*args, **kwargs):
    """Print on the rank-0 process only, then wait for every rank.

    Accepts the same positional and keyword arguments as the built-in
    ``print``.  When ``torch.distributed`` is unavailable or no default
    process group has been initialised (for example a single-process run),
    the message is printed directly and no barrier is issued, because
    ``dist.get_rank()`` and ``dist.barrier()`` both need an initialised
    process group.
    """
    if not (dist.is_available() and dist.is_initialized()):
        print(*args, **kwargs)
        return
    if dist.get_rank() == 0:
        print(*args, **kwargs)
    # Every rank blocks here until all ranks have reached this call.
    dist.barrier()


def get_cpu_mem():
    """Return the resident set size (RSS) psutil reports for this process."""
    current_process = psutil.Process()
    mem_info = current_process.memory_info()
    return mem_info.rss


class MemoryMonitor(Callback):
    """Lightning callback that prints CPU and CUDA memory usage around ``fit``.

    The CPU peak is sampled with ``get_cpu_mem`` after every training batch;
    the CUDA figures are read from ``torch.cuda``.  Every value is divided by
    ``1024**2`` and printed as MB through ``print_rank_0``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Largest value of get_cpu_mem() seen at the end of a training batch.
        self.max_cpu_mem = 0

    def on_fit_start(self, trainer, pl_module) -> None:
        """Print the memory figures measured before training starts."""
        mb = 1024**2
        peak_cuda = torch.cuda.max_memory_allocated()
        current_cuda = torch.cuda.memory_allocated()
        print_rank_0(f'CPU memory before training: {get_cpu_mem()/mb:.3f} MB')
        print_rank_0(f'CUDA memory before training: {current_cuda/mb:.3f} MB')
        print_rank_0(f'Max CUDA memory before training: {peak_cuda/mb:.3f} MB')

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
        """Update the running CPU-memory maximum."""
        sample = get_cpu_mem()
        if sample > self.max_cpu_mem:
            self.max_cpu_mem = sample

    def on_fit_end(self, trainer, pl_module) -> None:
        """Print the peak CPU and CUDA memory once fitting has finished."""
        mb = 1024**2
        peak_cuda = torch.cuda.max_memory_allocated()
        print_rank_0(f'Max CPU memory: {self.max_cpu_mem/mb:.3f} MB')
        print_rank_0(f'Max CUDA memory: {peak_cuda/mb:.3f} MB')


Writing /home/ubuntu/kevin.jung/colossal/callback.py


In [None]:
%%writefile /content/colossal/data.py
import torch

__all__ = ['RandomDataloader']


def get_data(batch_size, seq_len, vocab_size, device=None):
    """Generate one random batch of token ids with an all-ones attention mask.

    Args:
        batch_size: number of sequences in the batch.
        seq_len: number of tokens per sequence.
        vocab_size: exclusive upper bound of the sampled token ids.
        device: device on which the tensors are created.  ``None`` (the
            default) keeps the previous behaviour of using the current CUDA
            device.

    Returns:
        ``(input_ids, attention_mask)``, both of shape
        ``(batch_size, seq_len)`` and located on ``device``.
    """
    if device is None:
        device = torch.cuda.current_device()
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask


class RandomDataloader:
    """Iterable that yields ``n_steps`` random batches built by ``get_data``.

    Calling ``iter()`` on the object resets its step counter and returns the
    object itself.
    """

    def __init__(self, n_steps: int, batch_size: int, seq_len: int = 1024, vocab_size: int = 50257) -> None:
        # Arguments forwarded to get_data for every batch.
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        # Iteration bookkeeping.
        self.n_steps = n_steps
        self.cur_step = 0

    def __len__(self):
        return self.n_steps

    def __iter__(self):
        self.cur_step = 0
        return self

    def __next__(self):
        if self.cur_step < self.n_steps:
            self.cur_step += 1
            return get_data(self.batch_size, self.seq_len, self.vocab_size)
        raise StopIteration

def get_data_s2s(batch_size, seq_len, vocab_size, device=None):
    """Generate one random sequence-to-sequence batch.

    Args:
        batch_size: number of sequences in the batch.
        seq_len: number of tokens per sequence (encoder, decoder and labels).
        vocab_size: exclusive upper bound of the sampled token ids.
        device: device on which the tensors are created.  ``None`` (the
            default) keeps the previous behaviour of using the current CUDA
            device.

    Returns:
        ``(input_ids, attention_mask, decoder_input_ids,
        decoder_attention_mask, labels)``; every tensor has shape
        ``(batch_size, seq_len)``.  Both attention masks are all ones; the
        three id tensors are sampled independently.
    """
    if device is None:
        device = torch.cuda.current_device()
    shape = (batch_size, seq_len)
    input_ids = torch.randint(0, vocab_size, shape, device=device)
    attention_mask = torch.ones_like(input_ids)
    decoder_input_ids = torch.randint(0, vocab_size, shape, device=device)
    decoder_attention_mask = torch.ones_like(decoder_input_ids)
    labels = torch.randint(0, vocab_size, shape, device=device)
    return input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, labels

class RandomS2SDataloader(RandomDataloader):
    """``RandomDataloader`` variant whose batches come from ``get_data_s2s``."""

    def __next__(self):
        if self.cur_step < self.n_steps:
            self.cur_step += 1
            return get_data_s2s(self.batch_size, self.seq_len, self.vocab_size)
        raise StopIteration


Writing /home/ubuntu/kevin.jung/colossal/data.py


In [None]:
%%writefile /content/colossal/dataloader.py
import os
import warnings

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch

from datetime import datetime as dt
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning.trainer.supporters import CombinedLoader

from copy import deepcopy
from scipy.stats import poisson

__all__ = ["LanguageDataModule"]


class Pet_Dataset(Dataset):
    """Base dataset: loads a CSV file and offers token-level helper methods.

    Subclasses combine the helpers (``_encode``, ``_labeling``, ``_masking``,
    ``_label_masking``) in their own ``__getitem__``.
    """

    def __init__(self,
                 max_seq_len:int,
                 file_path:str,
                 tokenizer_path:str):
        self.max_seq_len = max_seq_len
        self.file_path = file_path
        # Rows with any missing value are dropped right after loading.
        self.data = pd.read_csv(file_path).dropna()
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        # The unpacking requires that encoding "[]" yields exactly two ids.
        bracket_ids = self.tokenizer.encode("[]")
        self.masking_start, self.masking_end = bracket_ids

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Return ``(input_ids, attention_mask)`` of length ``max_seq_len``.

        The text is wrapped in BOS/EOS tokens.  Short sequences are padded
        with the pad id (mask 0); long ones are cut and end with the EOS id.
        """
        tok = self.tokenizer
        pieces = [tok.bos_token] + tok.tokenize(text) + [tok.eos_token]
        input_ids = tok.convert_tokens_to_ids(pieces)
        attention_mask = [1] * len(input_ids)
        missing = self.max_seq_len - len(input_ids)
        if missing > 0:
            input_ids += [tok.pad_token_id] * missing
            attention_mask += [0] * missing
        else:
            input_ids = input_ids[:self.max_seq_len - 1] + [tok.eos_token_id]
            attention_mask = attention_mask[:self.max_seq_len]
        return input_ids, attention_mask

    def _labeling(self, label):
        """Return label ids of length ``max_seq_len``, padded with -100.

        No BOS token is added; an EOS token closes the sequence.
        """
        tok = self.tokenizer
        pieces = tok.tokenize(label) + [tok.eos_token]
        label_ids = tok.convert_tokens_to_ids(pieces)
        missing = self.max_seq_len - len(label_ids)
        if missing > 0:
            label_ids += [-100] * missing
        else:
            label_ids = label_ids[:self.max_seq_len - 1] + [tok.eos_token_id]
        return label_ids

    def _masking(self, tokens):
        """Replace every ``masking_start … masking_end`` span by mask ids.

        ``tokens`` is modified in place and also returned, together with the
        list of ``[start, end)`` index pairs that were masked.
        """
        masked = tokens  # same list object: the caller's list is edited
        mask_idx = []
        while self.masking_start in masked and self.masking_end in masked:
            start_idx = masked.index(self.masking_start)
            end_idx = masked.index(self.masking_end) + 1
            if not start_idx < end_idx:
                # The first end marker precedes the first start marker: stop.
                break
            mask_idx.append([start_idx, end_idx])
            masked[start_idx:end_idx] = [self.tokenizer.mask_token_id] * (end_idx - start_idx)
        return masked, mask_idx

    def _label_masking(self, tokens, idx):
        """Return ``tokens`` as an array with -100 outside the ``idx`` spans."""
        masked = np.array(tokens)
        index = []
        for span in idx:
            index.extend(range(span[0], span[1]))
        # True marks a position that lies in none of the spans.
        position = [n not in index for n in range(len(tokens))]
        masked[position] = [-100]
        return masked


class Masked_Dataset(Pet_Dataset):
    """Dataset whose encoder input gets bracket-span and random token masking."""

    def _random_masking(self, input, mask=None, ratio=0.15):
        """Return ``input`` as an array with random positions set to ``mask``.

        A position is replaced when its uniform draw is below ``ratio`` and
        its value is neither the BOS id nor the EOS id.  ``mask`` defaults to
        the tokenizer's mask token id.
        """
        tok = self.tokenizer
        mask_id = tok.mask_token_id if mask is None else mask
        input = np.array(input)
        draws = np.random.rand(input.size)
        not_special = np.logical_and(input != tok.bos_token_id, input != tok.eos_token_id)
        mask_arr = np.logical_and(draws < ratio, not_special)
        input[mask_arr.nonzero()] = mask_id
        return input

    def __getitem__(self, index):
        row = self.data.iloc[index]
        pattern, label = row["pattern"], row["label"]
        target_text = pattern + label
        enc_ids, enc_mask = self._encode(pattern)
        dec_ids, dec_mask = self._encode(target_text)
        labels = self._labeling(target_text)
        # Bracket spans are masked first; the span list itself is not used.
        enc_ids, _spans = self._masking(enc_ids)
        enc_ids = self._random_masking(enc_ids)

        sample = {}
        sample["input_ids"] = np.array(enc_ids, dtype=np.int_)
        sample["attention_mask"] = np.array(enc_mask, dtype=np.float32)
        sample["decoder_input_ids"] = np.array(dec_ids, dtype=np.int_)
        sample["decoder_attention_mask"] = np.array(dec_mask, dtype=np.float32)
        sample["labels"] = np.array(labels, dtype=np.int_)
        return sample


class Permutation_Dataset(Pet_Dataset):
    """Dataset whose encoder input has its tokens randomly permuted."""

    def _random_rotation(self, input):
        """Shuffle the tokens strictly between the first BOS and first EOS.

        Args:
            input: sequence of token ids containing the tokenizer's BOS and
                EOS ids (an ``IndexError`` is raised when either is absent).

        Returns:
            A numpy array with the same ids; BOS, EOS and everything from the
            EOS onwards keep their positions.
        """
        input = np.array(input)
        start_idx = np.where(input==self.tokenizer.bos_token_id)[0][0]
        end_idx = np.where(input==self.tokenizer.eos_token_id)[0][0]
        # start_idx + 1 keeps the BOS token itself out of the shuffle; the
        # slice is a view, so shuffling it edits ``input`` in place.
        np.random.shuffle(input[start_idx + 1:end_idx])
        return input

    def __getitem__(self, index):
        """Return the model inputs for row ``index`` as a dict of numpy arrays.

        The encoder sees the shuffled ``pattern``; the decoder input and the
        labels are both built from ``pattern + label``.
        """
        record = self.data.iloc[index]
        pattern, label = record["pattern"], record["label"]
        encoder_input_ids, encoder_attention_mask = self._encode(pattern)
        decoder_input_ids, decoder_attention_mask = self._encode(pattern+label)
        labels = self._labeling(pattern+label)
        encoder_input_ids = self._random_rotation(encoder_input_ids)

        return {"input_ids":np.array(encoder_input_ids, dtype=np.int_),
                "attention_mask":np.array(encoder_attention_mask,dtype=np.float32),
                "decoder_input_ids":np.array(decoder_input_ids, dtype=np.int_),
                "decoder_attention_mask":np.array(decoder_attention_mask,dtype=np.float32),
                "labels":np.array(labels,dtype=np.int_)}


class Deletion_Dataset(Pet_Dataset):
    """Dataset whose encoder input has random tokens deleted before the EOS."""

    def _random_deletion(self, input, ratio=0.15):
        """Delete random non-special tokens and pad the tail to keep the length.

        Only positions before the first EOS receive a random draw; positions
        from the EOS onwards get the constant 1.0.  A token is removed when
        its draw is below ``ratio`` and it is neither BOS nor EOS.  One pad id
        is appended per removed token.
        """
        tok = self.tokenizer
        input = np.array(input)
        eos_idx = np.where(input == tok.eos_token_id)[0][0]
        draws = np.random.rand(input[:eos_idx].size)
        draws = np.append(draws, [1.] * input[eos_idx:].size)
        not_special = np.logical_and(input != tok.bos_token_id, input != tok.eos_token_id)
        del_arr = np.logical_and(draws < ratio, not_special)
        removed = del_arr.nonzero()
        shortened = np.delete(input, removed)
        padding = [tok.pad_token_id] * removed[0].size
        return np.int_(np.append(shortened, padding))

    def __getitem__(self, index):
        row = self.data.iloc[index]
        pattern, label = row["pattern"], row["label"]
        target_text = pattern + label
        enc_ids, enc_mask = self._encode(pattern)
        dec_ids, dec_mask = self._encode(target_text)
        labels = self._labeling(target_text)
        enc_ids = self._random_deletion(enc_ids)

        sample = {}
        sample["input_ids"] = np.array(enc_ids, dtype=np.int_)
        sample["attention_mask"] = np.array(enc_mask, dtype=np.float32)
        sample["decoder_input_ids"] = np.array(dec_ids, dtype=np.int_)
        sample["decoder_attention_mask"] = np.array(dec_mask, dtype=np.float32)
        sample["labels"] = np.array(labels, dtype=np.int_)
        return sample


class Infilling_Dataset(Pet_Dataset):
    """Dataset whose encoder input is corrupted by ``_random_infilling``."""

    def _random_infilling(self, input, ratio=0.15, l=3):
        """Return a corrupted copy of ``input`` as an integer numpy array.

        Args:
            input: sequence of token ids containing the tokenizer's EOS id
                (an ``IndexError`` is raised when it is absent).
            ratio: threshold compared against the Poisson pmf values, and
                fraction used to size the random insertion in the fallback.
            l: parameter passed to ``scipy.stats.poisson``.
        """
        input = np.array(input)
        eos_idx = np.where(input==self.tokenizer.eos_token_id)[0][0]
        # Ids between index 1 and the first EOS (index 0 is skipped —
        # presumably the BOS token; confirm against _encode's output).
        text_range = input[1:eos_idx]
        # One boolean flag per token: True when the Poisson(l) pmf evaluated
        # at the token id value exceeds ``ratio``.
        # NOTE(review): the pmf is applied to the id values themselves, not
        # to span lengths — confirm this is intended.
        poi = poisson(l).pmf(text_range) > ratio
        # Flagged positions take their value from a helper array holding the
        # mask id at index 0 and -100 everywhere else; unflagged positions
        # keep the original id.
        infill = np.where(poi, np.array([self.tokenizer.mask_token_id]+[-100]*(poi.size-1)), text_range)
        # Drop every -100 entry and put the first input id back in front.
        infill = np.append(input[0], np.delete(infill,np.where(infill == -100)))
        # Append (number of flagged positions - 1) pad ids; with no flagged
        # position the count is -1 and the list multiplication yields [].
        infill = np.append(infill, [self.tokenizer.pad_token_id]*(np.count_nonzero(poi)-1))
        # Re-attach the EOS and everything that followed it.
        infill = np.append(infill, input[eos_idx:])
        # All flags are equal (all True or all False).
        # NOTE(review): poi is empty when no token lies between index 1 and
        # the EOS; check how max()/min() behave on an empty array here.
        if poi.max() == poi.min():
            if not len(infill) > len(input):
                # Insert one mask id at a random index drawn from
                # range(len(text_range)), then cut to the input length.
                infill = np.insert(infill, np.random.choice(len(text_range)),
                                   self.tokenizer.mask_token_id)[:len(input)]
            else:
                # Insert int(len(text_range)*ratio) mask ids at random
                # indices, cut to one less than the input length and close
                # the sequence with the EOS id.
                _size = int(len(text_range)*ratio)
                infill = np.insert(infill, np.random.choice(len(text_range),
                                                            size=_size), self.tokenizer.mask_token_id)[:len(input)-1]
                infill = np.append(infill, [self.tokenizer.eos_token_id])
        # When exactly one element short of max_seq_len, repeat the last
        # input id to reach the full length.
        if infill.size == (self.max_seq_len-1):
            infill = np.append(infill, input[-1])
        return np.int_(infill)

    def __getitem__(self, index):
        """Return the model inputs for row ``index`` as a dict of numpy arrays.

        The encoder sees the infilled ``pattern``; the decoder input and the
        labels are both built from ``pattern + label``.
        """
        record = self.data.iloc[index]
        pattern, label = record["pattern"], record["label"]
        encoder_input_ids, encoder_attention_mask = self._encode(pattern)
        decoder_input_ids, decoder_attention_mask = self._encode(pattern+label)
        labels = self._labeling(pattern+label)
        encoder_input_ids = self._random_infilling(encoder_input_ids)

        return {"input_ids":np.array(encoder_input_ids, dtype=np.int_),
                "attention_mask":np.array(encoder_attention_mask,dtype=np.float32),
                "decoder_input_ids":np.array(decoder_input_ids, dtype=np.int_),
                "decoder_attention_mask":np.array(decoder_attention_mask,dtype=np.float32),
                "labels":np.array(labels,dtype=np.int_)}

class RandomMultiToeknFilling_Dataset(
    Masked_Dataset, Permutation_Dataset, Deletion_Dataset, Infilling_Dataset):
    """Dataset applying one randomly chosen corruption to each encoder input.

    Unlike its parent classes, ``__getitem__`` returns a tuple, not a dict.
    """

    def _random_pattern(self, input):
        """Apply one of the four inherited corruption methods, picked at random."""
        corruptions = [self._random_masking,
                       self._random_rotation,
                       self._random_deletion,
                       self._random_infilling]
        chosen = np.random.choice(corruptions, 1)
        return chosen[0](input=input)

    def __getitem__(self, index):
        row = self.data.iloc[index]
        pattern, label = row["pattern"], row["label"]
        target_text = pattern + label
        enc_ids, enc_mask = self._encode(pattern)
        dec_ids, dec_mask = self._encode(target_text)
        labels = self._labeling(target_text)
        enc_ids = self._random_pattern(enc_ids)
        return (np.array(enc_ids, dtype=np.int_),
                np.array(enc_mask, dtype=np.float32),
                np.array(dec_ids, dtype=np.int_),
                np.array(dec_mask, dtype=np.float32),
                np.array(labels, dtype=np.int_))


class LanguageDataModule(pl.LightningDataModule):
    """Lightning data module serving train/val/test loaders from CSV files.

    ``setup`` builds one ``RandomMultiToeknFilling_Dataset`` loader per split;
    ``_load_multiple`` can build one loader per corruption type instead.
    """

    def __init__(self,
                 train_file:str,
                 val_file:str,
                 test_file:str,
                 tokenizer_path:str,
                 max_seq_len:int=1024,
                 batch_size:int=4,
                 num_workers:int=0,
                 pinned:bool=True):
        super().__init__()
        # Locations of the three splits.
        self.train_file_path = train_file
        self.val_file_path = val_file
        self.test_file_path = test_file
        # Dataset settings.
        self.tokenizer_path = tokenizer_path
        self.max_seq_len = max_seq_len
        # DataLoader settings.
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pinned = pinned

    def _make_loader(self, dataset_cls, file_path, shuffle):
        """Wrap a freshly built ``dataset_cls`` instance in a ``DataLoader``."""
        dataset = dataset_cls(
            max_seq_len=self.max_seq_len,
            file_path=file_path,
            tokenizer_path=self.tokenizer_path)
        return DataLoader(
            dataset,
            pin_memory=self.pinned,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle)

    def _load_multiple(self, file_path, shuffle):
        """Return a dict with one loader per corruption-specific dataset."""
        named_datasets = (
            ("Deletion Data", Deletion_Dataset),
            ("Permutation Data", Permutation_Dataset),
            ("Masked Data", Masked_Dataset),
            ("Infilling Data", Infilling_Dataset),
        )
        return {title: self._make_loader(dataset_cls, file_path, shuffle)
                for title, dataset_cls in named_datasets}

    def _load_random_token(self, file_path, shuffle):
        """Return a single loader over ``RandomMultiToeknFilling_Dataset``."""
        return self._make_loader(RandomMultiToeknFilling_Dataset, file_path, shuffle)

    def setup(self, stage):
        """Create the loaders of all three splits; only training is shuffled."""
        # Alternative kept for reference: build the splits with
        # self._load_multiple(...) and wrap them in
        # CombinedLoader(..., mode="max_size_cycle") in the *_dataloader hooks.
        self.train = self._load_random_token(file_path=self.train_file_path, shuffle=True)
        self.val = self._load_random_token(file_path=self.val_file_path, shuffle=False)
        self.test = self._load_random_token(file_path=self.test_file_path, shuffle=False)

    def train_dataloader(self):
        return self.train

    def val_dataloader(self):
        return self.val

    def test_dataloader(self):
        return self.test

In [None]:
%%writefile /content/colossal/colossal/model.py
import torch.nn as nn
import pytorch_lightning as pl
from transformers import GPT2Config, GPT2LMHeadModel, GPT2PreTrainedModel
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import colo_set_process_memory_fraction
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from torch.optim import Adam, Optimizer
from functools import partial
from typing import Callable, Iterable
from contextlib import contextmanager
__all__ = ['GPTLitModule', 'get_optimizer']


@contextmanager
def no_init_weights():
    """Temporarily replace ``GPT2PreTrainedModel._init_weights`` by a no-op.

    Inside the ``with`` body the class attribute is a function that does
    nothing; the previous attribute is put back on exit, also when the body
    raises.
    """
    def _noop(*_ignored):
        return None

    try:
        saved_init = GPT2PreTrainedModel._init_weights
        GPT2PreTrainedModel._init_weights = _noop
        yield
    finally:
        GPT2PreTrainedModel._init_weights = saved_init


class GPTLMModel(nn.Module):
    """``GPT2LMHeadModel`` wrapper whose forward pass returns only the logits.

    The wrapped model is built under ``no_init_weights()``.  With
    ``checkpoint=True`` gradient checkpointing is enabled and the forward
    pass is called with ``use_cache=False``.
    """

    def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257, checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        config = GPT2Config(
            vocab_size=vocab_size,
            n_positions=max_seq_len,
            n_ctx=max_seq_len,
            n_embd=hidden_size,
            n_layer=num_layers,
            n_head=num_attention_heads)
        with no_init_weights():
            self.model = GPT2LMHeadModel(config)
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=not self.checkpoint)
        # Element 0 of the model output holds the LM logits.
        return outputs[0]


def gpt2_tiny(checkpoint=True):
    """Build a GPTLMModel with hidden size 128, 4 layers and 4 attention heads."""
    dims = dict(hidden_size=128, num_layers=4, num_attention_heads=4)
    return GPTLMModel(**dims, checkpoint=checkpoint)


def gpt2_small(checkpoint=True):
    """Build a GPTLMModel with hidden size 768, 12 layers and 12 attention heads."""
    dims = dict(hidden_size=768, num_layers=12, num_attention_heads=12)
    return GPTLMModel(**dims, checkpoint=checkpoint)


def gpt2_medium(checkpoint=True):
    """Build a GPTLMModel with hidden size 1024, 24 layers and 16 attention heads."""
    dims = dict(hidden_size=1024, num_layers=24, num_attention_heads=16)
    return GPTLMModel(**dims, checkpoint=checkpoint)


def gpt2_large(checkpoint=True):
    """Build a GPTLMModel with hidden size 1280, 36 layers and 20 attention heads."""
    dims = dict(hidden_size=1280, num_layers=36, num_attention_heads=20)
    return GPTLMModel(**dims, checkpoint=checkpoint)


def gpt2_xl(checkpoint=True):
    """Build a GPTLMModel with hidden size 1600, 48 layers and 25 attention heads."""
    dims = dict(hidden_size=1600, num_layers=48, num_attention_heads=25)
    return GPTLMModel(**dims, checkpoint=checkpoint)


def gpt2_2B(checkpoint=True):
    """Build a GPTLMModel with hidden size 2048, 40 layers and 16 attention heads."""
    dims = dict(hidden_size=2048, num_layers=40, num_attention_heads=16)
    return GPTLMModel(**dims, checkpoint=checkpoint)


def gpt2_3B(checkpoint=True):
    """Build a GPTLMModel with hidden size 2304, 48 layers and 16 attention heads."""
    dims = dict(hidden_size=2304, num_layers=48, num_attention_heads=16)
    return GPTLMModel(**dims, checkpoint=checkpoint)


def gpt2_4B(checkpoint=True):
    """Build a GPTLMModel with hidden size 2304, 64 layers and 16 attention heads."""
    dims = dict(hidden_size=2304, num_layers=64, num_attention_heads=16)
    return GPTLMModel(**dims, checkpoint=checkpoint)


def gpt2_6B(checkpoint=True):
    """Build a GPTLMModel with hidden size 4096, 30 layers and 16 attention heads."""
    dims = dict(hidden_size=4096, num_layers=30, num_attention_heads=16)
    return GPTLMModel(**dims, checkpoint=checkpoint)


def gpt2_8B(checkpoint=True):
    """Build a GPTLMModel with hidden size 3072, 72 layers and 24 attention heads."""
    dims = dict(hidden_size=3072, num_layers=72, num_attention_heads=24)
    return GPTLMModel(**dims, checkpoint=checkpoint)


def gpt2_12B(checkpoint=True):
    """Build a GPTLMModel with hidden size 4096, 60 layers and 16 attention heads."""
    dims = dict(hidden_size=4096, num_layers=60, num_attention_heads=16)
    return GPTLMModel(**dims, checkpoint=checkpoint)


def gpt2_15B(checkpoint=True):
    """Build a GPTLMModel with hidden size 4096, 78 layers and 16 attention heads."""
    dims = dict(hidden_size=4096, num_layers=78, num_attention_heads=16)
    return GPTLMModel(**dims, checkpoint=checkpoint)


def gpt2_18B(checkpoint=True):
    """Build a GPTLMModel with hidden size 4096, 90 layers and 16 attention heads."""
    dims = dict(hidden_size=4096, num_layers=90, num_attention_heads=16)
    return GPTLMModel(**dims, checkpoint=checkpoint)


def gpt2_20B(checkpoint=True):
    """Build a GPTLMModel with hidden size 8192, 25 layers and 16 attention heads."""
    dims = dict(hidden_size=8192, num_layers=25, num_attention_heads=16)
    return GPTLMModel(**dims, checkpoint=checkpoint)


def gpt2_24B(checkpoint=True):
    """Build a GPTLMModel with hidden size 8192, 30 layers and 16 attention heads."""
    dims = dict(hidden_size=8192, num_layers=30, num_attention_heads=16)
    return GPTLMModel(**dims, checkpoint=checkpoint)


def gpt2_28B(checkpoint=True):
    """Build a GPTLMModel with hidden size 8192, 35 layers and 16 attention heads."""
    dims = dict(hidden_size=8192, num_layers=35, num_attention_heads=16)
    return GPTLMModel(**dims, checkpoint=checkpoint)


def gpt2_32B(checkpoint=True):
    """Build a GPTLMModel with hidden size 8192, 40 layers and 16 attention heads."""
    dims = dict(hidden_size=8192, num_layers=40, num_attention_heads=16)
    return GPTLMModel(**dims, checkpoint=checkpoint)


def gpt2_36B(checkpoint=True):
    """Build a GPTLMModel with hidden size 8192, 45 layers and 16 attention heads."""
    dims = dict(hidden_size=8192, num_layers=45, num_attention_heads=16)
    return GPTLMModel(**dims, checkpoint=checkpoint)


def gpt2_40B(checkpoint=True):
    """Build a GPTLMModel with hidden size 8192, 50 layers and 16 attention heads."""
    dims = dict(hidden_size=8192, num_layers=50, num_attention_heads=16)
    return GPTLMModel(**dims, checkpoint=checkpoint)


def gpt2_45B(checkpoint=True):
    """Build a GPTLMModel with hidden size 8192, 56 layers and 16 attention heads."""
    dims = dict(hidden_size=8192, num_layers=56, num_attention_heads=16)
    return GPTLMModel(**dims, checkpoint=checkpoint)


def gpt3(checkpoint=True):
    """Build a GPTLMModel with hidden size 12288, 96 layers, 96 heads and max_seq_len 2048."""
    dims = dict(max_seq_len=2048, hidden_size=12288, num_layers=96, num_attention_heads=96)
    return GPTLMModel(**dims, checkpoint=checkpoint)


def get_gpt_model(model_name: str, checkpoint: bool = True) -> nn.Module:
    """Build the GPT model registered under ``model_name``.

    Every registry key equals the name of its factory function.  An unknown
    name fails the assertion; ``checkpoint`` is forwarded to the factory.
    """
    factories = (
        gpt2_tiny, gpt2_small, gpt2_medium, gpt2_large, gpt2_xl,
        gpt2_2B, gpt2_3B, gpt2_4B, gpt2_6B, gpt2_8B,
        gpt2_12B, gpt2_15B, gpt2_18B, gpt2_20B, gpt2_24B,
        gpt2_28B, gpt2_32B, gpt2_36B, gpt2_40B, gpt2_45B,
        gpt3,
    )
    model_map = {factory.__name__: factory for factory in factories}
    assert model_name in model_map
    build = model_map[model_name]
    return build(checkpoint)


class GPTLMLoss(nn.Module):
    """Cross-entropy loss between logits at step ``t`` and labels at ``t + 1``."""

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        vocab = logits.size(-1)
        # Drop the last logit step and the first label so they line up.
        predictions = logits[..., :-1, :].contiguous()
        targets = labels[..., 1:].contiguous()
        # CrossEntropyLoss receives (N, vocab) scores and (N,) targets.
        flat_predictions = predictions.view(-1, vocab)
        flat_targets = targets.view(-1)
        return self.loss(flat_predictions, flat_targets)


def get_optimizer(strategy: str, **kwargs) -> Callable[[Iterable], Optimizer]:
    """Return a factory that builds the optimizer matching ``strategy``.

    Args:
        strategy: one of ``'ddp'``, ``'deepspeed'`` or ``'colossal'``.
        **kwargs: keyword arguments bound to the optimizer constructor.  The
            ``offload`` key is consumed here and never forwarded: for
            ``'deepspeed'`` a truthy value selects ``DeepSpeedCPUAdam``
            instead of ``FusedAdam``; for the other strategies it is ignored.
            It defaults to ``False`` when omitted.

    Returns:
        A ``functools.partial`` around the chosen optimizer class; call it
        with the parameters to optimise.

    Raises:
        AssertionError: if ``strategy`` is not one of the supported names.
    """
    assert strategy in ('ddp', 'deepspeed', 'colossal'), f"unsupported strategy: {strategy!r}"
    # ``offload`` only selects the DeepSpeed optimizer class.  Pop it for
    # every strategy so it is never passed on to an optimizer constructor.
    offload = kwargs.pop('offload', False)
    if strategy == 'ddp':
        opt_cls = Adam
    elif strategy == 'deepspeed':
        opt_cls = DeepSpeedCPUAdam if offload else FusedAdam
    else:
        opt_cls = HybridAdam
    return partial(opt_cls, **kwargs)


class GPTLitModule(pl.LightningModule):
    """Lightning module that trains a GPT model from ``get_gpt_model``.

    The model is created in ``configure_sharded_model``.  The training loss
    is ``GPTLMLoss`` with the input ids used as labels.
    """

    def __init__(self,
                 model_name: str,
                 optimizer_init_fn: Callable[[Iterable], Optimizer],
                 checkpoint: bool = True,
                 cuda_mem_fraction: float = 1.0,
                 model_checkpoint_dir: str = None) -> None:
        super().__init__()
        self.criterion = GPTLMLoss()
        # Model construction settings.
        self.model_name = model_name
        self.checkpoint = checkpoint
        # Training settings.
        self.optimizer_init_fn = optimizer_init_fn
        self.cuda_mem_fraction = cuda_mem_fraction
        # Directory passed to get_fp32_state_dict_from_zero_checkpoint.
        self.model_checkpoint = model_checkpoint_dir

    def configure_sharded_model(self) -> None:
        self.model = get_gpt_model(self.model_name, self.checkpoint)

    def configure_optimizers(self):
        return self.optimizer_init_fn(self.model.parameters())

    def on_fit_start(self) -> None:
        """Apply the CUDA memory fraction when it is below 1.0."""
        limit_requested = self.cuda_mem_fraction < 1.0
        if limit_requested:
            colo_set_process_memory_fraction(self.cuda_mem_fraction)

    def on_load_checkpoint(self, checkpoint) -> None:
        """Make sure the model exists, then load weights if a directory is set."""
        if not hasattr(self, 'model'):
            self.configure_sharded_model()
        if not self.model_checkpoint:
            return
        print(f"Load Checkpoint from {self.model_checkpoint}")
        state_dict = get_fp32_state_dict_from_zero_checkpoint(self.model_checkpoint)
        self.model.model.load_state_dict(state_dict)

    def training_step(self, batch, batch_idx):
        input_ids, attention_mask = batch
        logits = self.model(input_ids, attention_mask)
        return self.criterion(logits, input_ids)


Writing /home/ubuntu/kevin.jung/colossal/model.py


In [None]:
%%writefile /content/colossal/s2s_model.py
import gc
import torch
import os
import torch.nn as nn
import pytorch_lightning as pl
from transformers import BartForConditionalGeneration, BartModel, BartConfig, BartPretrainedModel
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import colo_set_process_memory_fraction

import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from deepspeed.accelerator import get_accelerator
from torch.optim import Adam, Optimizer
from functools import partial
from typing import Callable, Iterable
from contextlib import contextmanager
__all__ = ['S2SLitModule', 'get_optimizer']


@contextmanager
def no_init_weights():
    """Temporarily replace ``BartPretrainedModel._init_weights`` by a no-op.

    Inside the ``with`` body the class attribute is a function that does
    nothing; the previous attribute is put back on exit, also when the body
    raises.
    """
    def _noop(*_ignored):
        return None

    try:
        saved_init = BartPretrainedModel._init_weights
        BartPretrainedModel._init_weights = _noop
        yield
    finally:
        BartPretrainedModel._init_weights = saved_init


class S2SLMModel(nn.Module):
    """``BartForConditionalGeneration`` wrapper returning only the LM logits.

    Args:
        hidden_size: value passed as ``d_model``.
        num_layers: number of decoder layers.
        num_attention_heads: attention heads used by both the encoder and the
            decoder.
        prompt_layers: number of encoder layers.
        max_seq_len: value passed as ``max_position_embeddings``.
        vocab_size: vocabulary size of the model.
        checkpoint: when true, gradient checkpointing is enabled and the
            forward pass is called with ``use_cache=False``.
    """

    def __init__(self,
                 hidden_size:int=768,
                 num_layers:int=12,
                 num_attention_heads:int=12,
                 prompt_layers:int=3,
                 max_seq_len:int=1024,
                 vocab_size:int=64512, # 50257
                 checkpoint:bool=False):
        super().__init__()
        self.checkpoint = checkpoint
        with no_init_weights():
            self.model = BartForConditionalGeneration(
                BartConfig(
                    d_model=hidden_size,
                    encoder_layers=prompt_layers,
                    decoder_layers=num_layers,
                    # BartConfig names these ``*_attention_heads``; other
                    # keyword names are stored but not used to size the
                    # attention modules, so the head count would be ignored.
                    encoder_attention_heads=num_attention_heads,
                    decoder_attention_heads=num_attention_heads,
                    max_position_embeddings=max_seq_len,
                    n_ctx=max_seq_len,
                    vocab_size=vocab_size))
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask):
        """Run the wrapped model and return element 0 of its output (the logits)."""
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            use_cache=not self.checkpoint)[0]


def s2s_tiny(checkpoint=True):
    """Build an S2SLMModel with hidden_size=128, num_layers=4, num_attention_heads=4."""
    dims = dict(hidden_size=128, num_layers=4, num_attention_heads=4)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def s2s_small(checkpoint=True):
    """Build an S2SLMModel with hidden_size=768, num_layers=12, num_attention_heads=12."""
    dims = dict(hidden_size=768, num_layers=12, num_attention_heads=12)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def s2s_medium(checkpoint=True):
    """Build an S2SLMModel with hidden_size=1024, num_layers=24, num_attention_heads=16."""
    dims = dict(hidden_size=1024, num_layers=24, num_attention_heads=16)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def s2s_large(checkpoint=True):
    """Build an S2SLMModel with hidden_size=1280, num_layers=36, num_attention_heads=20."""
    dims = dict(hidden_size=1280, num_layers=36, num_attention_heads=20)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def s2s_xl(checkpoint=True):
    """Build an S2SLMModel with hidden_size=1600, num_layers=48, num_attention_heads=25."""
    dims = dict(hidden_size=1600, num_layers=48, num_attention_heads=25)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def s2s_2B(checkpoint=True):
    """Build an S2SLMModel with hidden_size=2048, num_layers=40, num_attention_heads=16."""
    dims = dict(hidden_size=2048, num_layers=40, num_attention_heads=16)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def s2s_9B(checkpoint=True):
    """Build an S2SLMModel with hidden_size=2048, num_layers=178, num_attention_heads=16."""
    dims = dict(hidden_size=2048, num_layers=178, num_attention_heads=16)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def s2s_3B(checkpoint=True):
    """Build an S2SLMModel with hidden_size=2304, num_layers=48, num_attention_heads=16."""
    dims = dict(hidden_size=2304, num_layers=48, num_attention_heads=16)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def s2s_4B(checkpoint=True):
    """Build an S2SLMModel with hidden_size=2304, num_layers=64, num_attention_heads=16."""
    dims = dict(hidden_size=2304, num_layers=64, num_attention_heads=16)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def s2s_6B(checkpoint=True):
    """Build an S2SLMModel with hidden_size=4096, num_layers=30, num_attention_heads=16."""
    dims = dict(hidden_size=4096, num_layers=30, num_attention_heads=16)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def s2s_8B(checkpoint=True):
    """Build an S2SLMModel with hidden_size=3072, num_layers=72, num_attention_heads=24."""
    dims = dict(hidden_size=3072, num_layers=72, num_attention_heads=24)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def s2s_12B(checkpoint=True):
    """Build an S2SLMModel with hidden_size=4096, num_layers=60, num_attention_heads=16."""
    dims = dict(hidden_size=4096, num_layers=60, num_attention_heads=16)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def s2s_15B(checkpoint=True):
    """Build an S2SLMModel with hidden_size=4096, num_layers=78, num_attention_heads=16."""
    dims = dict(hidden_size=4096, num_layers=78, num_attention_heads=16)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def s2s_18B(checkpoint=True):
    """Build an S2SLMModel with hidden_size=4096, num_layers=90, num_attention_heads=16."""
    dims = dict(hidden_size=4096, num_layers=90, num_attention_heads=16)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def s2s_20B(checkpoint=True):
    """Build an S2SLMModel with hidden_size=8192, num_layers=25, num_attention_heads=16."""
    dims = dict(hidden_size=8192, num_layers=25, num_attention_heads=16)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def s2s_24B(checkpoint=True):
    """Build an S2SLMModel with hidden_size=8192, num_layers=30, num_attention_heads=16."""
    dims = dict(hidden_size=8192, num_layers=30, num_attention_heads=16)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def s2s_28B(checkpoint=True):
    """Build an S2SLMModel with hidden_size=8192, num_layers=35, num_attention_heads=16."""
    dims = dict(hidden_size=8192, num_layers=35, num_attention_heads=16)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def s2s_32B(checkpoint=True):
    """Build an S2SLMModel with hidden_size=8192, num_layers=40, num_attention_heads=16."""
    dims = dict(hidden_size=8192, num_layers=40, num_attention_heads=16)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def s2s_36B(checkpoint=True):
    """Build an S2SLMModel with hidden_size=8192, num_layers=45, num_attention_heads=16."""
    dims = dict(hidden_size=8192, num_layers=45, num_attention_heads=16)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def s2s_40B(checkpoint=True):
    """Build an S2SLMModel with hidden_size=8192, num_layers=50, num_attention_heads=16."""
    dims = dict(hidden_size=8192, num_layers=50, num_attention_heads=16)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def s2s_45B(checkpoint=True):
    """Build an S2SLMModel with hidden_size=8192, num_layers=56, num_attention_heads=16."""
    dims = dict(hidden_size=8192, num_layers=56, num_attention_heads=16)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def s2s_LLM(checkpoint=True):
    """Build an S2SLMModel with hidden_size=12288, num_layers=96, 96 heads and max_seq_len=2048."""
    dims = dict(max_seq_len=2048, hidden_size=12288, num_layers=96, num_attention_heads=96)
    return S2SLMModel(**dims, checkpoint=checkpoint)


def get_s2s_model(model_name: str, checkpoint: bool = True) -> nn.Module:
    """Build the sequence-to-sequence model registered under ``model_name``.

    An unknown name fails the assertion; ``checkpoint`` is forwarded to the
    selected factory.
    """
    model_map = dict(
        s2s_tiny=s2s_tiny,
        s2s_small=s2s_small,
        s2s_medium=s2s_medium,
        s2s_large=s2s_large,
        s2s_xl=s2s_xl,
        s2s_2B=s2s_2B,
        s2s_3B=s2s_3B,
        s2s_4B=s2s_4B,
        s2s_6B=s2s_6B,
        s2s_8B=s2s_8B,
        s2s_12B=s2s_12B,
        s2s_15B=s2s_15B,
        s2s_18B=s2s_18B,
        s2s_20B=s2s_20B,
        s2s_24B=s2s_24B,
        s2s_28B=s2s_28B,
        s2s_32B=s2s_32B,
        s2s_36B=s2s_36B,
        s2s_40B=s2s_40B,
        s2s_45B=s2s_45B,
        s2s_LLM=s2s_LLM,
    )
    # This key differs from its factory's name, so it is added separately.
    model_map['s2s_9.2B'] = s2s_9B
    assert model_name in model_map
    build = model_map[model_name]
    return build(checkpoint)


class GPTLMLoss(nn.Module):
    """Cross-entropy loss between logits at step ``t`` and labels at ``t + 1``."""

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        # Drop the last logit step and the first label so they line up.
        kept_logits = logits[..., :-1, :].contiguous()
        kept_labels = labels[..., 1:].contiguous()
        num_classes = kept_logits.size(-1)
        # CrossEntropyLoss receives (N, num_classes) scores and (N,) targets.
        scores = kept_logits.view(-1, num_classes)
        targets = kept_labels.view(-1)
        return self.loss(scores, targets)


def get_optimizer(strategy: str, **kwargs) -> Callable[[Iterable], Optimizer]:
    """Return a factory that builds the optimizer matching ``strategy``.

    Args:
        strategy: one of ``'ddp'``, ``'deepspeed'`` or ``'colossal'``.
        **kwargs: keyword arguments bound to the optimizer constructor.  The
            ``offload`` key is consumed here and never forwarded: for
            ``'deepspeed'`` a truthy value selects ``DeepSpeedCPUAdam``
            instead of ``FusedAdam``; for the other strategies it is ignored.
            It defaults to ``False`` when omitted.

    Returns:
        A ``functools.partial`` around the chosen optimizer class; call it
        with the parameters to optimise.

    Raises:
        AssertionError: if ``strategy`` is not one of the supported names.
    """
    assert strategy in ('ddp', 'deepspeed', 'colossal'), f"unsupported strategy: {strategy!r}"
    # ``offload`` only selects the DeepSpeed optimizer class.  Pop it for
    # every strategy so it is never passed on to an optimizer constructor.
    offload = kwargs.pop('offload', False)
    if strategy == 'ddp':
        opt_cls = Adam
    elif strategy == 'deepspeed':
        opt_cls = DeepSpeedCPUAdam if offload else FusedAdam
    else:
        opt_cls = HybridAdam
    return partial(opt_cls, **kwargs)


class S2SLitModule(pl.LightningModule):
    """LightningModule that trains a model built by ``get_s2s_model``.

    The wrapped model is created lazily in ``configure_sharded_model`` and is
    trained with ``GPTLMLoss`` on batches of
    ``(input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, labels)``.

    Args:
        model_name: Name passed to ``get_s2s_model``.
        optimizer_init_fn: Callable that receives the model parameters and
            returns the optimizer.
        checkpoint: Passed to ``get_s2s_model`` as its ``checkpoint`` argument.
        cuda_mem_fraction: When below 1.0, passed to
            ``colo_set_process_memory_fraction`` at fit start.
        model_checkpoint_dir: Optional checkpoint location used by
            ``on_load_checkpoint``; ignored when falsy.
    """

    def __init__(self,
                 model_name: str,
                 optimizer_init_fn: Callable[[Iterable], Optimizer],
                 checkpoint: bool = True,
                 cuda_mem_fraction: float = 1.0,
                 model_checkpoint_dir: str = None) -> None:
        super().__init__()
        self.model_name = model_name
        self.optimizer_init_fn = optimizer_init_fn
        self.checkpoint = checkpoint
        self.criterion = GPTLMLoss()
        self.cuda_mem_fraction = cuda_mem_fraction
        self.model_checkpoint = model_checkpoint_dir

    def _save_in_hub_(self) -> None:
        """Save the inner model under /content/llm_checkpoint/ and push it to the Hub.

        The target repository and auth token are read from the
        ``MODEL_SAVE_REPO`` and ``HUGGINGFACE_AUTO_TOKEN`` environment
        variables. NOTE: nothing in this class calls this method; it has to be
        invoked explicitly.
        """
        self.model.model.save_pretrained(
            save_directory="/content/llm_checkpoint/",
            push_to_hub=True,
            max_shard_size="124MB",
            repo_id=os.getenv("MODEL_SAVE_REPO"),
            use_auth_token=os.getenv("HUGGINGFACE_AUTO_TOKEN"))

    def __memory_clean__(self) -> None:
        """Release cached accelerator/CUDA memory and run the garbage collector."""
        get_accelerator().empty_cache()
        torch.cuda.empty_cache()
        gc.collect()

    def configure_sharded_model(self) -> None:
        """Instantiate the wrapped model from ``model_name``."""
        self.model = get_s2s_model(
            model_name=self.model_name,
            checkpoint=self.checkpoint)

    def on_load_checkpoint(self, checkpoint) -> None:
        """Make sure the model exists, then optionally load weights into it.

        When ``model_checkpoint_dir`` was given, the state dict returned by
        ``get_fp32_state_dict_from_zero_checkpoint`` for that location is
        loaded into ``self.model.model``.
        """
        if not hasattr(self, 'model'):
            self.configure_sharded_model()
        if self.model_checkpoint:
            print(f"Load Checkpoint from {self.model_checkpoint}")
            self.model.model.load_state_dict(
                get_fp32_state_dict_from_zero_checkpoint(self.model_checkpoint))

    def configure_optimizers(self):
        """Build the optimizer over the wrapped model's parameters."""
        return self.optimizer_init_fn(self.model.parameters())

    def _single_batch_loss(self, fields):
        """Run the model on one batch and return its loss against ``labels``."""
        input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, labels = fields
        logits = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask)
        return self.criterion(logits, labels)

    def _compute_loss(self, batch):
        """Return the loss for a plain batch, or the summed loss for a dict of batches."""
        if not isinstance(batch, dict):
            return self._single_batch_loss(batch)
        # One sub-batch per key. Each sub-batch is itself a mapping whose
        # values are unpacked positionally, so its insertion order must be
        # (input_ids, attention_mask, decoder_input_ids,
        #  decoder_attention_mask, labels).
        losses = torch.empty(0, device=torch.cuda.current_device())
        for sub_batch in batch.values():
            loss = self._single_batch_loss(sub_batch.values())
            losses = torch.cat((losses, loss.view(-1)))
        return losses.sum()

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss.

        The loss is always computed against ``labels``; the tuple-batch path
        previously scored the logits against ``input_ids`` while every other
        path used ``labels``.
        """
        loss = self._compute_loss(batch)
        # Logged as "loss" so that callbacks monitoring that key (train.py's
        # ModelCheckpoint uses monitor="loss") can find it.
        self.log("loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation loss.

        The value used to be computed and then discarded; it is now logged as
        "val_loss" (the key referenced by train.py's checkpoint filename) and
        returned.
        """
        loss = self._compute_loss(batch)
        self.log("val_loss", loss)
        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Free cached memory after every training batch.

        Lightning's hook is named ``on_train_batch_end``; the original
        ``on_training_batch_end`` spelling was never invoked by the trainer.
        """
        self.__memory_clean__()

    def on_training_batch_end(self, outputs, batch, batch_idx):
        """Backward-compatible alias of ``on_train_batch_end`` for direct callers."""
        self.on_train_batch_end(outputs, batch, batch_idx)

    def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        """Free cached memory after every validation batch."""
        self.__memory_clean__()

    def on_fit_start(self) -> None:
        """Apply the process memory-fraction limit when one below 1.0 was requested."""
        if self.cuda_mem_fraction < 1.0:
            colo_set_process_memory_fraction(self.cuda_mem_fraction)


In [None]:
%%writefile /content/colossal/colossal/train.py
import pytorch_lightning as pl
import argparse
import warnings
import logging
from data import RandomDataloader, RandomS2SDataloader
from dataloader import LanguageDataModule
from model import GPTLitModule, get_optimizer
from s2s_model import S2SLitModule
from callback import MemoryMonitor
from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.strategies.deepspeed import DeepSpeedStrategy
from pytorch_lightning.strategies.colossalai import ColossalAIStrategy
# from pytorch_lightning.plugins.deepspeed import Deepspeed
import warnings
warnings.filterwarnings('ignore')

if __name__ == '__main__':
    # ---- Command-line options -------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--tqdm_rate', type=int, default=2000)
    parser.add_argument('--epochs', type=int, default=2)
    parser.add_argument('--steps_per_epoch', type=int, default=4)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--model', default='gpt2_xl')
    parser.add_argument('--np', type=int, default=1)
    parser.add_argument('--no_activation_ckpt', action='store_true', default=False)
    parser.add_argument('--opt_nvme_offload_frac', type=float, default=0.0)
    parser.add_argument('--opt_nvme_offload_dir', default='./offload')
    parser.add_argument('--seq_len', type=int, default=1024)
    parser.add_argument('--placement_policy', default='cuda')
    parser.add_argument('--opt_gpu_margin_rat', type=float, default=0.0)
    parser.add_argument('--cuda_mem_frac', type=float, default=1.0)
    parser.add_argument('--strategy', default='ddp', choices=['ddp', 'colossal', 'deepspeed'])
    parser.add_argument('--offload', action='store_true', default=False)
    # Directory the ModelCheckpoint callback writes into.
    parser.add_argument('--model_checkpoint_dir', default='/data/llm_checkpoint/')
    # Checkpoint location handed to the Lightning module as `model_checkpoint_dir`.
    parser.add_argument('--model_zero_ckpt_dir', default=None)
    args = parser.parse_args()

    # ---- Model family and data source -----------------------------------
    # Model names containing "gpt" train on random token batches; every other
    # name uses the seq2seq module together with the CSV-backed data module.
    if "gpt" in args.model:
        lit_module = GPTLitModule
        train_dataloader = RandomDataloader(args.steps_per_epoch, args.batch_size, args.seq_len)
        data_module = None
    else:
        lit_module = S2SLitModule
        train_dataloader = None
        data_module = LanguageDataModule(
            train_file="/data/kevin.jung/Train.csv",
            val_file="/data/kevin.jung/Dev_s.csv",
            test_file="/data/kevin.jung/Test.csv",
            tokenizer_path="Gunulhona/tb_tokenizer_big",
            max_seq_len=args.seq_len,
            batch_size=args.batch_size)

    optimizer_cfg = {
        'lr': args.lr
        }

    # ---- Strategy-specific trainer / optimizer settings -----------------
    # `--strategy` is restricted by argparse `choices`, so exactly one branch
    # below runs and `trainer_cfg` is always defined.
    if args.strategy == 'ddp':
        trainer_cfg = {
            'accelerator': 'gpu',
            'precision': 16,
            'strategy': DDPStrategy(static_graph=True)
        }

    elif args.strategy == 'colossal':
        trainer_cfg = {
            'accelerator': 'gpu',
            'precision': 16,
            'strategy': ColossalAIStrategy(
                placement_policy=args.placement_policy,
                gpu_margin_mem_ratio=args.opt_gpu_margin_rat,
                initial_scale=32,
                chunk_search_range= 64 * 1024**2,
                chunk_search_n_grids= 4096,
                min_chunk_size= 32 * 1024**2)
            }

        optimizer_cfg['nvme_offload_dir'] = args.opt_nvme_offload_dir
        optimizer_cfg['nvme_offload_fraction'] = args.opt_nvme_offload_frac

    elif args.strategy == 'deepspeed':
        trainer_cfg = {
            'accelerator': 'gpu',
            'precision': 16,
            'strategy': DeepSpeedStrategy(
                stage=3,
                offload_parameters=args.offload,
                offload_optimizer=args.offload,
                initial_scale_power=5,
                load_full_weights=True,
                logging_batch_size_per_gpu=args.batch_size,
                logging_level=logging.ERROR)  # added because too many warnings piled up in the logs
            }

        optimizer_cfg['offload'] = args.offload

    opt_init_fn = get_optimizer(args.strategy, **optimizer_cfg)

    model = lit_module(
        model_name=args.model,
        optimizer_init_fn=opt_init_fn,
        checkpoint=not args.no_activation_ckpt,
        cuda_mem_fraction=args.cuda_mem_frac,
        model_checkpoint_dir=args.model_zero_ckpt_dir)

    trainer = pl.Trainer(
        max_epochs=args.epochs,
        devices=args.np,
        enable_checkpointing=True,
        callbacks=[
            MemoryMonitor(),
            TQDMProgressBar(
                refresh_rate=args.tqdm_rate),
            ModelCheckpoint(
                dirpath=args.model_checkpoint_dir,
                mode="min",
                monitor="loss",
                # ModelCheckpoint appends ".ckpt" itself; spelling it out in
                # the template produced files ending in ".ckpt.ckpt".
                filename="llm-{epoch:02d}-{val_loss:.4f}",
                every_n_train_steps=1,
                save_last=True)],
        fast_dev_run=False,
        profiler="advanced",
        **trainer_cfg)

    # Exactly one of `train_dataloader` / `data_module` is non-None.
    trainer.fit(
        model=model,
        train_dataloaders=train_dataloader,
        datamodule=data_module)


Overwriting /home/ubuntu/kevin.jung/colossal/train.py


In [None]:
%%writefile /content/colossal/colossal/train_start
# Launcher for the training script: sets the run configuration below and
# forwards it to $EXECUTEFILE as command-line flags.
# export CUDA_LAUNCH_BLOCKING="1"
# export CUDA_VISIBLE_DEVICES="0,1,2,3"
export TOKENIZERS_PARALLELISM="0"


EXECUTEFILE="colossal/train.py"                         # needs custom trainer path
EPOCHS=100                                              # type=int       default=2
TQDM_RATE=2000                                          # type=int       default=2000
LEARNING_RATE=5e-5                                      # type=float     default=1e-3
STRATEGY="colossal"                                     # type=str       default='ddp'         choices=['ddp', 'colossal', 'deepspeed']
ACCELERATOR="gpu"                                       # type=str       default=gpu           NOTE(review): not passed to the command below
NP=-1                                                   # type=int       default=1
BATCHSIZE=1                                             # type=int       default=1
MODEL_NAME='gpt2_2B'                                    # type=str       default='gpt2_xl'     choices=['gpt2_tiny'~'gpt2_xl'~'gpt3']
STEPS_PER_EPOCH=4                                       # type=int       default=4
NAC=false                                               # type=bool      default=False         action='store_true'   NOTE(review): unused, see below
OFFLOAD=false                                           # type=bool      default=False         action='store_true'   NOTE(review): unused, see below
OPT_NVME_OFFLAND_FRAC=0.0                               # type=float     default=0.0
OPT_NVME_OFFLAND_DIR='/data/opt/'                       # type=str       default='/data/opt'
SEQ_LEN=1024                                            # type=int       default=1024
PLACEMENT_POLICY='cuda'                                 # type=str       default='cuda'
OPT_GPU_MARGIN_RAT=0.0                                  # type=float     default=0.0
CUDA_MEMORY_FRAC=1.0                                    # type=float     default=1.0
MODEL_ZERO_CKPT_DIR='/data/llm_checkpoint/last.ckpt'    # type=str       default=None

# NOTE(review): --no_activation_ckpt and --offload are passed unconditionally
# below; the NAC and OFFLOAD variables above are never read, so setting them
# to false does not disable those flags. Confirm which behavior is intended.
python $EXECUTEFILE\
  --tqdm_rate $TQDM_RATE\
  --model $MODEL_NAME\
  --epochs $EPOCHS\
  --steps_per_epoch $STEPS_PER_EPOCH\
  --batch_size $BATCHSIZE\
  --seq_len $SEQ_LEN\
  --cuda_mem_frac $CUDA_MEMORY_FRAC\
  --np $NP\
  --strategy $STRATEGY\
  --placement_policy $PLACEMENT_POLICY\
  --lr $LEARNING_RATE \
  --no_activation_ckpt\
  --offload\
  --opt_nvme_offload_frac $OPT_NVME_OFFLAND_FRAC\
  --opt_nvme_offload_dir $OPT_NVME_OFFLAND_DIR\
  --opt_gpu_margin_rat $OPT_GPU_MARGIN_RAT\
  --model_zero_ckpt_dir $MODEL_ZERO_CKPT_DIR


Overwriting /home/ubuntu/kevin.jung/colossal/train_start


In [None]:
# Upload the saved checkpoint folder to the Hugging Face Hub.
# "REPOID" and 'HUGGINFACE_AUTH_TOKEN' are placeholders to be replaced with
# the real repository id and access token before running this cell.
from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path="/content/llm_checkpoint/last.ckpt/",
    repo_id="REPOID",
    repo_type="model",
    token='HUGGINFACE_AUTH_TOKEN',
)