# GPT-J Summariser Model

In [None]:
!pip install torch
!pip install accelerate
!pip install transformers
!pip install pandas

[0mCollecting accelerate
  Downloading accelerate-0.21.0-py3-none-any.whl (244 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m244.2/244.2 kB[0m [31m15.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting numpy>=1.17 (from accelerate)
  Downloading numpy-1.24.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.3/17.3 MB[0m [31m140.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting psutil (from accelerate)
  Downloading psutil-5.9.5-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (282 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m282.1/282.1 kB[0m [31m81.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: psutil, numpy, accelerate
Successfully installed accelerate-0.21.0 numpy-1.24.4 psutil-5.9.5
[0mCollecting transformers
  Downloading transformers-4.31.0-py3-none-any.whl (7.4 MB)
[2K     

In [None]:
from torch.utils.data import Dataset
import pandas as pd
import torch


def balance_dataframe(df, column_label, random_state=None):
    """Downsample every class in ``df`` to the size of the rarest class.

    Args:
        df: DataFrame holding the samples.
        column_label: Name of the column that contains the class labels.
        random_state: Optional seed (or random state) forwarded to
            ``DataFrame.sample`` so the downsampling can be reproduced.

    Returns:
        A new DataFrame with the same number of rows for every class, grouped
        by class in ``value_counts`` order. When ``df`` has no labelled rows an
        empty frame with the same columns is returned.
    """
    class_counts = df[column_label].value_counts()
    if class_counts.empty:
        # Nothing to balance (empty frame, or every label is missing).
        return df.iloc[0:0].copy()

    min_class = int(class_counts.min())
    # ``DataFrame.append`` was removed in pandas 2.0; collect the per-class
    # samples and concatenate them once instead of growing a frame in a loop.
    balanced_parts = [
        df[df[column_label] == class_value].sample(min_class, random_state=random_state)
        for class_value in class_counts.index
    ]
    return pd.concat(balanced_parts)

def prompt_tokenize(prompt, completion, tokenizer, max_len, truncation=True, padding=True):
    """Tokenize a prompt/completion pair into fixed-length training tensors.

    Args:
        prompt: Text shown to the model; excluded from the loss.
        completion: Text the model should produce; its tokens are the ones
            flagged in the returned loss mask.
        tokenizer: Object exposing ``encode(text) -> list[int]`` and, when
            padding is needed, ``pad_token_id``.
        max_len: Target sequence length.
        truncation: When true, the prompt is cut so prompt + completion fit in
            ``max_len``. A completion longer than ``max_len`` is itself cut to
            ``max_len`` tokens and the prompt is dropped entirely.
        padding: When true, sequences shorter than ``max_len`` are padded on
            the right; longer ones (only possible with ``truncation=False``)
            are cropped on the right to ``max_len``.

    Returns:
        ``(input_ids, attention_mask, loss_mask)``, each of shape
        ``(1, seq_len)``. ``input_ids`` and ``attention_mask`` are int64,
        ``loss_mask`` is bool and is true on completion tokens only.
    """
    prompt_toks = tokenizer.encode(prompt)
    completion_toks = tokenizer.encode(completion)

    if truncation:
        # Guard against a completion that alone fills (or overflows) the
        # budget: a negative slice bound would keep most of the prompt and the
        # completion would later be cropped away, leaving an empty loss mask.
        completion_toks = completion_toks[:max_len]
        prompt_budget = max(0, max_len - len(completion_toks))
        prompt_toks = prompt_toks[:prompt_budget]

    n_prompt = len(prompt_toks)
    sample = torch.tensor(prompt_toks + completion_toks, dtype=torch.long).unsqueeze(0)
    loss_mask = torch.zeros((1, sample.shape[1]), dtype=torch.bool)
    loss_mask[:, n_prompt:] = True
    attention_mask = torch.ones(sample.shape, dtype=torch.long)

    if padding:
        pad_len = max_len - sample.shape[1]
        if pad_len > 0:
            sample = torch.cat(
                [sample, torch.full((1, pad_len), tokenizer.pad_token_id, dtype=sample.dtype)], dim=1
            )
            loss_mask = torch.cat(
                [loss_mask, torch.zeros((1, pad_len), dtype=loss_mask.dtype)], dim=1
            )
            attention_mask = torch.cat(
                [attention_mask, torch.zeros((1, pad_len), dtype=attention_mask.dtype)], dim=1
            )
        elif pad_len < 0:
            # Over-long, untruncated input: crop on the right so every sample
            # still has exactly ``max_len`` positions.
            sample = sample[:, :max_len]
            loss_mask = loss_mask[:, :max_len]
            attention_mask = attention_mask[:, :max_len]
    return sample, attention_mask, loss_mask


class PromptDataset(Dataset):
    """Map-style dataset that tokenizes prompt/completion rows on access.

    The wrapped DataFrame is expected to expose ``prompt`` and ``completion``
    columns; tokenization of each row is delegated to ``prompt_tokenize``.
    """

    def __init__(self, data_df, tokenizer, max_prompt_len=100, truncation=True, padding=True):
        """Store the data frame, tokenizer and tokenization options."""
        self.df = data_df
        self.tokenizer = tokenizer
        self.max_prompt_len = max_prompt_len
        self.truncation = truncation
        self.padding = padding

    def __len__(self):
        """Number of rows in the wrapped DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, loss_mask)`` for row ``idx``."""
        row = self.df.iloc[idx]
        return prompt_tokenize(
            row['prompt'],
            row['completion'],
            self.tokenizer,
            self.max_prompt_len,
            self.truncation,
            self.padding,
        )

    @staticmethod
    def create_prompt(text):
        """Wrap ``text`` in the message-classification prompt template."""
        return f''' Classify the following messages into one of the following categories: [Hate Speech], [Offensive language], [Neutral]

Message: {text}

Category: '''


In [None]:
!pip install bitsandbytes
!pip install scipy

Collecting bitsandbytes
  Downloading bitsandbytes-0.41.0-py3-none-any.whl (92.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.6/92.6 MB[0m [31m60.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.41.0
[0mCollecting scipy
  Downloading scipy-1.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (34.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m34.5/34.5 MB[0m [31m113.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Installing collected packages: scipy
Successfully installed scipy-1.10.1
[0m

In [None]:
import transformers

import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import custom_fwd, custom_bwd

from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise

from tqdm.auto import tqdm


class FrozenBNBLinear(nn.Module):
    """Linear layer whose weight is kept as frozen, blockwise-quantized buffers.

    The quantized ``weight``, its per-block ``absmax`` and the quantization
    ``code`` are registered as buffers (not parameters), so they receive no
    gradient. An optional trainable ``adapter`` module can be attached; its
    output is added to the frozen layer's output.

    Args:
        weight: Quantized weight of shape ``(out_features, in_features)``.
        absmax: Per-block scaling values produced by the blockwise quantizer.
        code: Quantization code book.
        bias: Optional bias; must be an ``nn.Parameter`` or ``None``.
    """

    def __init__(self, weight, absmax, code, bias=None):
        assert isinstance(bias, nn.Parameter) or bias is None
        super().__init__()
        self.out_features, self.in_features = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None
        self.bias = bias

    def forward(self, input):
        """Apply the dequantized linear map and add the adapter output, if any."""
        output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
        # Work on a copy so the in-place add below never touches the tensor
        # returned by the custom autograd function.
        output = output.clone()
        # Compare against None explicitly: truth-testing an nn.Module calls
        # ``__len__`` on container modules, so an empty container would be
        # silently treated as "no adapter".
        if self.adapter is not None:
            output += self.adapter(input)
        return output

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
        """Build a frozen 8-bit layer by quantizing an existing ``nn.Linear``."""
        weights_int8, state = quantize_blockise_lowmemory(linear.weight)
        return cls(weights_int8, *state, linear.bias)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"


class DequantizeAndLinear(torch.autograd.Function):
    """Autograd function: dequantize blockwise-quantized weights, then apply ``F.linear``.

    Only the input and the bias receive gradients; the quantized weight,
    ``absmax`` and ``code`` are treated as constants.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
                absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
        """Dequantize the weights and return ``F.linear(input, weights, bias)``."""
        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        # Save the quantized tensors rather than ``weights_deq``; the weights
        # are dequantized again in ``backward``.
        ctx.save_for_backward(input, weights_quantized, absmax, code)
        ctx._has_bias = bias is not None
        return F.linear(input, weights_deq, bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output: torch.Tensor):
        """Return gradients for ``input`` and ``bias``; ``None`` for the quantized state."""
        # The quantized weight, absmax and code must never require gradients.
        assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
        input, weights_quantized, absmax, code = ctx.saved_tensors
        # grad_output: [*batch, out_features]
        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        grad_input = grad_output @ weights_deq
        # Collapse all leading dimensions and sum over them to get one value per output feature.
        grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
        # One entry per forward argument: input, weights_quantized, absmax, code, bias.
        return grad_input, None, None, None, grad_bias


class FrozenBNBEmbedding(nn.Module):
    """Embedding table kept as frozen, blockwise-quantized buffers.

    The quantized ``weight``, its per-block ``absmax`` and the quantization
    ``code`` are registered as buffers. An optional trainable ``adapter``
    module can be attached; its output is added to the looked-up embeddings.

    Args:
        weight: Quantized table of shape ``(num_embeddings, embedding_dim)``.
        absmax: Per-block scaling values produced by the blockwise quantizer.
        code: Quantization code book.
    """

    def __init__(self, weight, absmax, code):
        super().__init__()
        self.num_embeddings, self.embedding_dim = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None

    def forward(self, input, **kwargs):
        """Look up embeddings in the dequantized table and add the adapter output, if any.

        Extra keyword arguments are forwarded to ``F.embedding``.
        """
        with torch.no_grad():
            # note: both quantized weights and input indices are *not* differentiable
            weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
            output = F.embedding(input, weight_deq, **kwargs)
        # Compare against None explicitly: truth-testing an nn.Module calls
        # ``__len__`` on container modules, so an empty container would be
        # silently treated as "no adapter".
        if self.adapter is not None:
            output += self.adapter(input)
        return output

    @classmethod
    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
        """Build a frozen 8-bit embedding by quantizing an existing ``nn.Embedding``."""
        weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
        return cls(weights_int8, *state)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"


def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
    """Blockwise-quantize ``matrix`` one chunk at a time.

    The flattened tensor is processed in slices of ``chunk_size`` elements,
    each passed to ``quantize_blockwise``; the code returned for one chunk is
    fed into the next call.

    Args:
        matrix: Tensor to quantize.
        chunk_size: Elements per chunk; must be a multiple of 4096.

    Returns:
        ``(matrix_i8, (absmax, code))`` where ``matrix_i8`` has the shape of
        ``matrix`` and ``absmax`` concatenates the per-chunk absmax values.
    """
    assert chunk_size % 4096 == 0
    flattened = matrix.view(-1)
    n_chunks = (matrix.numel() - 1) // chunk_size + 1

    code = None
    quantized_parts, absmax_parts = [], []
    for chunk_index in range(n_chunks):
        begin = chunk_index * chunk_size
        piece = flattened[begin: begin + chunk_size].clone()
        quantized_piece, (absmax_piece, code) = quantize_blockwise(piece, code=code)
        quantized_parts.append(quantized_piece)
        absmax_parts.append(absmax_piece)

    matrix_i8 = torch.cat(quantized_parts).reshape_as(matrix)
    return matrix_i8, (torch.cat(absmax_parts), code)


def convert_to_int8(model, verbose=False):
    """Replace ``nn.Linear`` / ``nn.Embedding`` children of ``model`` with frozen 8-bit layers.

    Every replaced child becomes a ``FrozenBNBLinear`` or ``FrozenBNBEmbedding``
    of matching shape whose buffers are zero-filled placeholders (presumably
    overwritten later when quantized weights are loaded). The model is modified
    in place; no adapters are attached here.

    Args:
        model: Module whose sub-modules are scanned recursively.
        verbose: When true, print the name and repr of each replaced linear
            layer. Off by default because large models contain hundreds of
            linear layers and the log drowns the notebook output.
    """
    for module in list(model.modules()):
        # Snapshot the children so replacing them cannot disturb the iteration.
        for name, child in list(module.named_children()):
            if isinstance(child, nn.Linear):
                if verbose:
                    print(name, child)
                replacement = FrozenBNBLinear(
                    weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
                    # One absmax entry per block of 4096 weights (rounded up).
                    absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                    code=torch.zeros(256),
                    bias=child.bias,
                )
            elif isinstance(child, nn.Embedding):
                replacement = FrozenBNBEmbedding(
                    weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
                    absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                    code=torch.zeros(256),
                )
            else:
                continue
            setattr(module, name, replacement)

class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
    """GPT-J block whose attention and MLP sub-modules are converted with ``convert_to_int8``."""

    def __init__(self, config):
        super().__init__(config)

        # Convert attention first, then the MLP, right after the stock block is built.
        for submodule in (self.attn, self.mlp):
            convert_to_int8(submodule)


class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
    """GPT-J base model that runs ``convert_to_int8`` on itself right after construction."""

    def __init__(self, config):
        super().__init__(config)
        # Swap the remaining nn.Linear / nn.Embedding children for frozen 8-bit layers.
        convert_to_int8(self)


class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
    """GPT-J causal-LM model that runs ``convert_to_int8`` on itself right after construction."""

    def __init__(self, config):
        super().__init__(config)
        # Swap the remaining nn.Linear / nn.Embedding children for frozen 8-bit layers.
        convert_to_int8(self)

  warn("The installed version of bitsandbytes was compiled without GPU support. "


/usr/local/lib/python3.8/dist-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32


In [None]:
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.0.1-py3-none-any.whl (729 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m729.2/729.2 kB[0m [31m30.7 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.7.0 (from torchmetrics)
  Downloading lightning_utilities-0.9.0-py3-none-any.whl (23 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.9.0 torchmetrics-1.0.1
[0m

In [None]:
!pip install pytorch_lightning

Collecting pytorch_lightning
  Downloading pytorch_lightning-2.0.6-py3-none-any.whl (722 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m722.8/722.8 kB[0m [31m31.2 MB/s[0m eta [36m0:00:00[0m
Collecting aiohttp!=4.0.0a0,!=4.0.0a1 (from fsspec[http]>2021.06.0->pytorch_lightning)
  Downloading aiohttp-3.8.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m143.3 MB/s[0m eta [36m0:00:00[0m
Collecting multidict<7.0,>=4.5 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning)
  Downloading multidict-6.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (121 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m121.3/121.3 kB[0m [31m57.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting async-timeout<5.0,>=4.0.0a3 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning)
  Downloading async_timeout-4.0.2-py3-

In [None]:
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torchmetrics
from torch.nn.functional import cross_entropy
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl
import transformers
from bitsandbytes.optim import Adam8bit




# Point transformers' module-level GPTJBlock name at the 8-bit subclass defined
# in this notebook. NOTE(review): presumably this makes GPT-J models built
# afterwards use the 8-bit block — confirm against the installed transformers version.
transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock

@dataclass
class FinetunerConfig:
    """Hyperparameters consumed by the fine-tuning module."""

    # Learning rate handed to the optimizer.
    lr: float = 1e-3
    # Batch size used by the train and validation data loaders.
    batch_size: int = 1
    # NOTE(review): not read anywhere in the visible notebook — confirm usage.
    warmup_steps: int = 0
    # NOTE(review): not read anywhere in the visible notebook — confirm usage.
    num_epochs: int = 1
    # Bottleneck width intended for the adapters.
    adapter_dim: int = 1
    # When true, the shared step also returns argmax token predictions.
    classification: bool = False

class GPTJ8bitFineTuner(pl.LightningModule):
    """Lightning module that fine-tunes an 8-bit GPT-J causal language model.

    Args:
        model_name: Identifier passed to ``GPTJForCausalLM.from_pretrained``.
        model_post_init_func: Optional callable invoked with the loaded model
            (for example to attach adapters); skipped when falsy.
        fine_tuning_config: Object exposing ``lr``, ``batch_size`` and
            ``classification`` (see ``FinetunerConfig``).
        train_dataset: Dataset yielding ``(input_ids, attention_mask, loss_mask)``.
        val_dataset: Optional validation dataset of the same form.
    """

    def __init__(self, model_name, model_post_init_func, fine_tuning_config, train_dataset, val_dataset=None):
        super().__init__()
        self.model = GPTJForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True)
        self.config = fine_tuning_config
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.validation_step_outputs = []
        self.training_step_outputs = []
        if model_post_init_func:
            model_post_init_func(self.model)

    def forward(self, input_ids, attention_masks):
        """Run the wrapped model and return its output object."""
        # Call the module itself rather than ``.forward`` so registered hooks run.
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_masks,
        )

    def common_step(self, batch, batch_idx):
        """Compute the completion-only loss shared by training and validation.

        Returns:
            ``(loss, preds, completion_tok_ids)``; ``preds`` holds the argmax
            token ids when ``config.classification`` is true, otherwise ``None``.
        """
        input_ids, attention_masks, loss_mask = batch

        out = self(
            input_ids=input_ids,
            attention_masks=attention_masks,
        )

        # The logits at position t predict token t + 1, so the mask is shifted
        # one step to the left to pick the logits that predict each completion
        # token. ``dims=2`` is the sequence axis: every dataset item carries a
        # leading singleton dimension, so a batch is (batch, 1, seq).
        logits = out.logits[loss_mask.roll(shifts=-1, dims=2)]
        completion_tok_ids = input_ids[loss_mask]
        loss = cross_entropy(logits, completion_tok_ids)
        preds = None
        if self.config.classification:
            preds = torch.argmax(logits, dim=1)

        return loss, preds, completion_tok_ids

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch and log it as ``train_loss``."""
        loss, _, _ = self.common_step(batch, batch_idx)
        # Store a detached copy only: keeping the live loss would retain the
        # autograd graph of every step and steadily exhaust memory.
        self.training_step_outputs.append(loss.detach())
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and return ``(loss, n_correct, n_total)``.

        The two counts are ``None`` unless ``config.classification`` is true.
        """
        loss, preds, labels = self.common_step(batch, batch_idx)
        self.log('val_loss', loss)
        self.validation_step_outputs.append(loss)
        if self.config.classification:
            trues = torch.sum(preds == labels).cpu()
            total = len(labels)
            return loss, trues, total
        return loss, None, None

    def on_validation_epoch_end(self):
        """Log the mean validation loss of the epoch and reset the buffer."""
        if not self.validation_step_outputs:
            # No validation batch was processed; ``torch.stack`` rejects an empty list.
            return
        epoch_average = torch.stack(self.validation_step_outputs).mean()
        self.log("validation_epoch_average", epoch_average)
        self.validation_step_outputs.clear()  # free memory

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
        )

    def val_dataloader(self):
        """Loader over the validation dataset, or ``None`` when none was given."""
        # Explicit None check: truth-testing a dataset calls ``__len__``.
        if self.val_dataset is not None:
            # Validation order does not need shuffling; keep evaluation deterministic.
            return torch.utils.data.DataLoader(
                self.val_dataset,
                batch_size=self.config.batch_size,
                shuffle=False,
            )

    def configure_optimizers(self):
        """8-bit Adam over the wrapped model's parameters."""
        optimizer = Adam8bit(self.model.parameters(), lr=self.config.lr)

        return optimizer


In [None]:
import torch


def add_all_adapters(model, adapter_dim=2):
    """Attach a low-rank adapter to every frozen 8-bit linear and embedding layer.

    Each adapter is a two-stage ``Sequential`` (project down to ``adapter_dim``,
    then back up to the layer's output width). The second stage starts at
    zero, so a freshly attached adapter contributes nothing to the output.

    Args:
        model: Module whose sub-modules are scanned.
        adapter_dim: Bottleneck width of the adapters; must be positive.
    """
    assert adapter_dim > 0

    for layer in model.modules():
        if isinstance(layer, FrozenBNBLinear):
            down_proj = torch.nn.Linear(layer.in_features, adapter_dim, bias=False)
            up_proj = torch.nn.Linear(adapter_dim, layer.out_features, bias=False)
        elif isinstance(layer, FrozenBNBEmbedding):
            down_proj = torch.nn.Embedding(layer.num_embeddings, adapter_dim)
            up_proj = torch.nn.Linear(adapter_dim, layer.embedding_dim, bias=False)
        else:
            continue

        torch.nn.init.zeros_(up_proj.weight)
        layer.adapter = torch.nn.Sequential(down_proj, up_proj)

def add_attention_adapters(model, adapter_dim=2):
    """Attach a low-rank adapter to frozen 8-bit linear layers whose name contains ``"attn"``.

    Each adapter is two bias-free linear layers (down to ``adapter_dim``, then
    back up); the second one starts at zero. The name of every adapted layer
    is printed.

    Args:
        model: Module whose named sub-modules are scanned.
        adapter_dim: Bottleneck width of the adapters; must be positive.
    """
    assert adapter_dim > 0

    for layer_name, layer in model.named_modules():
        if not isinstance(layer, FrozenBNBLinear):
            continue
        if "attn" not in layer_name:
            continue

        print("Adding adapter to", layer_name)
        down_proj = torch.nn.Linear(layer.in_features, adapter_dim, bias=False)
        up_proj = torch.nn.Linear(adapter_dim, layer.out_features, bias=False)
        layer.adapter = torch.nn.Sequential(down_proj, up_proj)
        torch.nn.init.zeros_(layer.adapter[1].weight)

In [None]:
!pip install wandb

Collecting wandb
  Downloading wandb-0.15.7-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m64.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting Click!=8.0.0,>=7.1 (from wandb)
  Downloading click-8.1.6-py3-none-any.whl (97 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m97.9/97.9 kB[0m [31m54.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.32-py3-none-any.whl (188 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m188.5/188.5 kB[0m [31m81.4 MB/s[0m eta [36m0:00:00[0m
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-1.28.1-py2.py3-none-any.whl (214 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m214.7/214.7 kB[0m [31m89.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting pat

In [None]:
!pip install datasets

Collecting datasets
  Downloading datasets-2.14.0-py3-none-any.whl (492 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m492.2/492.2 kB[0m [31m24.0 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow>=8.0.0 (from datasets)
  Downloading pyarrow-12.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (39.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m39.0/39.0 MB[0m [31m104.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m59.5 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (213 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m213.0/213.0 kB[0m [31m82.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downlo

In [None]:
from functools import partial

import wandb
import pytorch_lightning as pl

from pytorch_lightning.loggers import WandbLogger
import pandas as pd
import transformers

transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock
from datasets import load_dataset
# Load the CNN/DailyMail training split and keep only the columns used below.
dataset = load_dataset('cnn_dailymail','3.0.0', split='train')
data = {
    'article': dataset['article'],
    'highlights': dataset['highlights'],
    'id': dataset['id']
}

# Create a DataFrame from the extracted data
train_df = pd.DataFrame(data)
sample_size = 6400  # Specify the desired sample size
train_df = train_df.sample(n=sample_size, random_state=42)
# Print the first few rows of the DataFrame
print(train_df.head())

# Same procedure for the validation split.
dataset_1=load_dataset('cnn_dailymail','3.0.0', split='validation')
data_1 = {
    'article': dataset_1['article'],
    'highlights': dataset_1['highlights'],
    'id': dataset_1['id']
}

# Create a DataFrame from the extracted data
val_df = pd.DataFrame(data_1)
sample_size1 = 298  # Specify the desired sample size
# Use the validation sample size here; the training size (sample_size) was
# passed by mistake, which drew 6400 validation rows instead of 298.
val_df = val_df.sample(n=sample_size1, random_state=42)
print(val_df.head())

Downloading builder script:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/9.88k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/15.1k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/5 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/159M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/376M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/12.3M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/661k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/572k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13368 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11490 [00:00<?, ? examples/s]

                                                  article  \
272581  Nasa has warned of an impending asteroid pass ...   
772     BAGHDAD, Iraq (CNN) -- Iraq's most powerful Su...   
171868  By . David Kent . Andy Carroll has taken an un...   
63167   Los Angeles (CNN) -- Los Angeles has long been...   
68522   London (CNN) -- Few shows can claim such an au...   

                                               highlights  \
272581  2004 BL86 will pass about three times the dist...   
772     Iraqi Islamic Party calls Quran incident "blat...   
171868  Carroll takes to Instagram to post selfie ahea...   
63167   Pop stars from all over Europe are setting the...   
68522   NEW: Young athletes light the Olympic cauldron...   

                                              id  
272581  6ccb7278e86893ad3609d30ecb5c9ea902fb9527  
772     d4f57e3c18c38696345fb7a3d76a151bb9c5123b  
171868  c9ae9fc314adcc92d3835b0437a1c44e9e233e1c  
63167   5b5a383dc8f9487857787ced5426154394dd99db  
68522   281

In [None]:

def create_summarization_instruction_prompt(text):
    """Wrap ``text`` in an instruction-style prompt that ends with ``Summary:``."""
    return f'''Summarize the following text:

Text: {text}

Summary:'''

def create_summarization_raw_prompt(highlights):
    """Append the ``###`` end-of-prompt separator to the given text.

    Args:
        highlights: Text to turn into a prompt. Despite the parameter name,
            the notebook applies this function to the ``article`` column.

    Returns:
        The text followed by a space and a ``###`` separator that has two
        newlines on each side.
    """
    # The separator must use real newline escapes; the original '/n/n###/n/n'
    # put literal forward slashes into every prompt.
    prompt = f'{highlights} \n\n###\n\n'
    return prompt

# Build the 'prompt' and 'completion' columns that PromptDataset reads.
# Prompts come from the article text; completions are the reference highlights,
# prefixed with a space so the completion text starts with whitespace.
train_df['prompt'] = train_df['article'].apply(create_summarization_raw_prompt)
val_df['prompt'] = val_df['article'].apply(create_summarization_raw_prompt)
# NOTE(review): batch_size is not read anywhere in the visible notebook (the
# fine-tuning config carries its own batch size) — confirm before relying on it.
batch_size = 128
train_df['completion'] = train_df['highlights'].apply(lambda x: ' ' + x)
val_df['completion'] = val_df['highlights'].apply(lambda x: ' ' + x)
# Keep an unmodified copy of the reference summary next to the completion.
train_df['summary'] = train_df['highlights']
val_df['summary'] = val_df['highlights']


train_df.head()

Unnamed: 0,article,highlights,id,prompt,completion,summary
272581,Nasa has warned of an impending asteroid pass ...,2004 BL86 will pass about three times the dist...,6ccb7278e86893ad3609d30ecb5c9ea902fb9527,Nasa has warned of an impending asteroid pass ...,2004 BL86 will pass about three times the dis...,2004 BL86 will pass about three times the dist...
772,"BAGHDAD, Iraq (CNN) -- Iraq's most powerful Su...","Iraqi Islamic Party calls Quran incident ""blat...",d4f57e3c18c38696345fb7a3d76a151bb9c5123b,"BAGHDAD, Iraq (CNN) -- Iraq's most powerful Su...","Iraqi Islamic Party calls Quran incident ""bla...","Iraqi Islamic Party calls Quran incident ""blat..."
171868,By . David Kent . Andy Carroll has taken an un...,Carroll takes to Instagram to post selfie ahea...,c9ae9fc314adcc92d3835b0437a1c44e9e233e1c,By . David Kent . Andy Carroll has taken an un...,Carroll takes to Instagram to post selfie ahe...,Carroll takes to Instagram to post selfie ahea...
63167,Los Angeles (CNN) -- Los Angeles has long been...,Pop stars from all over Europe are setting the...,5b5a383dc8f9487857787ced5426154394dd99db,Los Angeles (CNN) -- Los Angeles has long been...,Pop stars from all over Europe are setting th...,Pop stars from all over Europe are setting the...
68522,London (CNN) -- Few shows can claim such an au...,NEW: Young athletes light the Olympic cauldron...,2813505a990ad24071496c0d0936e40847eb6194,London (CNN) -- Few shows can claim such an au...,NEW: Young athletes light the Olympic cauldro...,NEW: Young athletes light the Olympic cauldron...


In [None]:
# Create torch Datasets with the prepared fine-tuning samples.
# Load the tokenizer and reuse the end-of-sequence token as the padding token.
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer.pad_token = tokenizer.eos_token

# Define max_prompt_size. Used to pad short samples and truncate long ones, which is needed for
# batching and for fitting into memory. It is the 0.99 quantile of the tokenized prompt+completion
# length plus a margin of 5 tokens, capped at the tokenizer's model_max_length: longer sequences
# trigger the "longer than the specified maximum sequence length" warning and cannot be run
# through the model.
token_lengths = pd.Series(
    len(tokenizer.tokenize(e)) for e in (train_df['prompt'] + ' ' + train_df['completion'])
)
max_prompt_size = min(int(token_lengths.quantile(0.99)) + 5, tokenizer.model_max_length)

train_dataset = PromptDataset(train_df, tokenizer, max_prompt_len=max_prompt_size)
val_dataset = PromptDataset(val_df, tokenizer, max_prompt_len=max_prompt_size)

Downloading (…)okenizer_config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.37M [00:00<?, ?B/s]

Downloading (…)in/added_tokens.json:   0%|          | 0.00/4.04k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/357 [00:00<?, ?B/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (2104 > 2048). Running this sequence through the model will result in indexing errors


In [None]:
!pip install wandb

[0m

In [None]:
import wandb
from pytorch_lightning.loggers import WandbLogger
from functools import partial


from dataclasses import asdict

# Logging in to WandB
wandb.login()

# Single source of truth for the hyperparameters: the same object is logged to
# WandB and handed to the fine-tuner, so the two can never drift apart.
config = FinetunerConfig(
    lr=1e-4,
    batch_size=2,
    num_epochs=3,
    adapter_dim=2,
    classification=True
    )

wandb.init(
    # set the wandb project where this run will be logged
    project = "WANDB_PROJECT",

    # track hyperparameters and run metadata (passed as a plain dict)
    config = asdict(config),
)

# Choose a way to finetune (Adapters for all linear layers including embedding)
model_post_init_func = partial(add_all_adapters, adapter_dim=config.adapter_dim)

# Create the GPTJ8bitFineTuner instance
finetuner = GPTJ8bitFineTuner(
    model_name="hivemind/gpt-j-6B-8bit",
    model_post_init_func=model_post_init_func,
    fine_tuning_config=config,
    train_dataset=train_dataset,
    val_dataset=val_dataset
)

# NOTE(review): this closes the WandB run before any training happens in this
# cell — confirm that is intended.
wandb.finish()






k_proj Linear(in_features=4096, out_features=4096, bias=False)
v_proj Linear(in_features=4096, out_features=4096, bias=False)
q_proj Linear(in_features=4096, out_features=4096, bias=False)
out_proj Linear(in_features=4096, out_features=4096, bias=False)
fc_in Linear(in_features=4096, out_features=16384, bias=True)
fc_out Linear(in_features=16384, out_features=4096, bias=True)
k_proj Linear(in_features=4096, out_features=4096, bias=False)
v_proj Linear(in_features=4096, out_features=4096, bias=False)
q_proj Linear(in_features=4096, out_features=4096, bias=False)
out_proj Linear(in_features=4096, out_features=4096, bias=False)
fc_in Linear(in_features=4096, out_features=16384, bias=True)
fc_out Linear(in_features=16384, out_features=4096, bias=True)
k_proj Linear(in_features=4096, out_features=4096, bias=False)
v_proj Linear(in_features=4096, out_features=4096, bias=False)
q_proj Linear(in_features=4096, out_features=4096, bias=False)
out_proj Linear(in_features=4096, out_features=4096, 

In [None]:
# Take the wrapped causal-LM model out of the Lightning module, place it on the
# CPU and switch it to evaluation mode for the generation examples below.
model = finetuner.model.to('cpu')
model.eval()

GPTJForCausalLM(
  (transformer): GPTJModel(
    (wte): FrozenBNBEmbedding(50400, 4096)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0): GPTJBlock(
        (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): GPTJAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): FrozenBNBLinear(4096, 4096)
          (v_proj): FrozenBNBLinear(4096, 4096)
          (q_proj): FrozenBNBLinear(4096, 4096)
          (out_proj): FrozenBNBLinear(4096, 4096)
        )
        (mlp): GPTJMLP(
          (fc_in): FrozenBNBLinear(4096, 16384)
          (fc_out): FrozenBNBLinear(16384, 4096)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
      (1): GPTJBlock(
        (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): GPTJAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
 

In [None]:
# Test sample
prompt = '''Summarise the Following Message
Message: TFootball superstar, celebrity, fashion icon, multimillion-dollar heartthrob. Now, David Beckham is headed for the Hollywood Hills as he takes his game to U.S. Major League Soccer.

Summary:'''

# Tokenize the prompt into PyTorch tensors.
sample = tokenizer(prompt, return_tensors='pt')

# Move every tokenizer output tensor to the CPU
sample = {k: v.to('cpu') for k, v in sample.items()}


In [None]:
# Sample a continuation; max_length is set to the prompt length plus 20 tokens
gen_tokens = model.generate(**sample,
               temperature=0.2,
               do_sample=True,
               max_length=(sample['input_ids'].shape[-1]) + 20)
print(tokenizer.decode(gen_tokens[0]))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Summarise the Following Message
Message: TFootball superstar, celebrity, fashion icon, multimillion-dollar heartthrob. Now, David Beckham is headed for the Hollywood Hills as he takes his game to U.S. Major League Soccer.

Summary:

David Beckham is a professional footballer who has played for the English Premier League club, Manchester United


In [None]:
prompt = '''Summarise the Following Message
Message: SAN FRANCISCO, California (CNN) -- A magnitude 4.2 earthquake shook the San Francisco area Friday at 4:42 a.m. PT (7:42 a.m. ET), the U.S. Geological Survey reported. The quake left about 2,000 customers without power, said David Eisenhower, a spokesman for Pacific Gas and Light. Under the USGS classification, a magnitude 4.2 earthquake is considered "light," which it says usually causes minimal damage. "We had quite a spike in calls, mostly calls of inquiry, none of any injury, none of any damage that was reported," said Capt. Al Casciato of the San Francisco police.

Summary:'''

# Tokenize the prompt into PyTorch tensors.
sample_1 = tokenizer(prompt, return_tensors='pt')

# Move every tokenizer output tensor to the CPU
sample_1 = {k: v.to('cpu') for k, v in sample_1.items()}

In [None]:
# Sample a continuation; max_length is set to the prompt length plus 20 tokens
gen_tokens = model.generate(**sample_1,
                            temperature=0.2,
                            do_sample=True,
                            max_length=(sample_1['input_ids'].shape[-1]) + 20)

# Decode and print the generated tokens
print(tokenizer.decode(gen_tokens[0]))


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Summarise the Following Message
Message: SAN FRANCISCO, California (CNN) -- A magnitude 4.2 earthquake shook the San Francisco area Friday at 4:42 a.m. PT (7:42 a.m. ET), the U.S. Geological Survey reported. The quake left about 2,000 customers without power, said David Eisenhower, a spokesman for Pacific Gas and Light. Under the USGS classification, a magnitude 4.2 earthquake is considered "light," which it says usually causes minimal damage. "We had quite a spike in calls, mostly calls of inquiry, none of any injury, none of any damage that was reported," said Capt. Al Casciato of the San Francisco police.

Summary:

The earthquake was centered about 10 miles (16 kilometers) below the surface, the USGS


In [None]:
prompt = '''Summarise the Following Message
Message:A virus found in healthy Australian honey bees may be playing a role in the collapse of honey bee colonies across the United States, researchers reported Thursday. Honey bees walk on a moveable comb hive at the Bee Research Laboratory, in Beltsville, Maryland. Colony collapse disorder has killed millions of bees -- up to 90 percent of colonies in some U.S. beekeeping operations -- imperiling the crops largely dependent upon bees for pollination, such as oranges, blueberries, apples and almonds. The U.S. Department of Agriculture says honey bees are responsible for pollinating $15 billion worth of crops each year in the United States. More than 90 fruits and vegetables worldwide depend on them for pollination. Signs of colony collapse disorder were first reported in the United States in 2004, the same year American beekeepers started importing bees from Australia. The disorder is marked by hives left with a queen, a few newly hatched adults and plenty of food, but the worker bees responsible for pollination gone. The virus identified in the healthy Australian bees is Israeli Acute Paralysis Virus (IAPV) -- named that because it was discovered by Hebrew University researchers. Although worker bees in colony collapse disorder vanish, bees infected with IAPV die close to the hive, after developing shivering wings and paralysis. For some reason, the Australian bees seem to be resistant to IAPV and do not come down with symptoms. Scientists used genetic analyses of bees collected over the past three years and found that IAPV was present in bees that had come from colony collapse disorder hives 96 percent of the time. But the study released Thursday on the Science Express Web site, operated by the journal Science, cautioned that collapse disorder is likely caused by several factors. "This research give us a very good lead to follow, but we do not believe IAPV is acting alone," said Jeffery S. Pettis of the U.S. 
Department of Agriculture's Bee Research Laboratory and a co-author of the study. "Other stressors on the colony are likely involved." This could explain why bees in Australia may be resistant to colony collapse.

Summary:'''

# Tokenize the prompt into PyTorch tensors.
sample_2 = tokenizer(prompt, return_tensors='pt')

# Move every tokenizer output tensor to the CPU
sample_2 = {k: v.to('cpu') for k, v in sample_2.items()}

In [None]:
# Sample a continuation; max_length is set to the prompt length plus 60 tokens
gen_tokens = model.generate(**sample_2,
                            temperature=0.2,
                            do_sample=True,
                            max_length=(sample_2['input_ids'].shape[-1]) + 60)

# Decode and print the generated tokens
print(tokenizer.decode(gen_tokens[0]))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Summarise the Following Message
Message:A virus found in healthy Australian honey bees may be playing a role in the collapse of honey bee colonies across the United States, researchers reported Thursday. Honey bees walk on a moveable comb hive at the Bee Research Laboratory, in Beltsville, Maryland. Colony collapse disorder has killed millions of bees -- up to 90 percent of colonies in some U.S. beekeeping operations -- imperiling the crops largely dependent upon bees for pollination, such as oranges, blueberries, apples and almonds. The U.S. Department of Agriculture says honey bees are responsible for pollinating $15 billion worth of crops each year in the United States. More than 90 fruits and vegetables worldwide depend on them for pollination. Signs of colony collapse disorder were first reported in the United States in 2004, the same year American beekeepers started importing bees from Australia. The disorder is marked by hives left with a queen, a few newly hatched adults and pl