<a href="https://colab.research.google.com/github/robgon-art/DeepHaiku/blob/main/Deep_Haiku_Generator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Initialize the System

# Install the Python packages and clone the helper repositories used by the rest of the notebook.
!pip install --upgrade --no-cache-dir gdown
!git clone https://github.com/unitaryai/detoxify
!pip install transformers==4.16.2
!pip install bitsandbytes-cuda111
!git clone https://github.com/robgon-art/GRUEN
!pip install wmd
# NOTE(review): identical to the first gdown install in this cell; looks redundant.
!pip install --upgrade --no-cache-dir gdown
# Download two files by ID with gdown. Presumably these are cola_model.zip (unzipped
# just below) and the GPT-J checkpoint that is loaded later -- TODO confirm.
!gdown --id 1S-l0L_YOzn5KhYHdB8iS37qKwuUhHP0G
!gdown --id 10LpkO5Vm_zOu723FVk6cCeRsv_qyYLdL
!unzip cola_model.zip
# phonemizer is called with backend='festival' below, so the festival binary is installed too.
!pip install phonemizer
!sudo apt-get install festival

import transformers
import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import custom_fwd, custom_bwd
from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
from tqdm.auto import tqdm
from phonemizer import phonemize
from phonemizer.separator import Separator
import nltk
nltk.download('punkt')
import sys
sys.path.append("GRUEN")
import GRUEN.Main as gruen
import sys
sys.path.append("/content/detoxify")
from detoxify import Detoxify

def get_festival_phonemes(line):
  """Phonemize ``line`` using phonemizer's festival backend for US English.

  The separator is configured with a space between words and "|" between
  syllables, stress marks are disabled, and ``strip=True`` is passed through
  to ``phonemize``.

  Args:
    line: The text handed to ``phonemize`` unchanged.

  Returns:
    Whatever ``phonemize`` returns for ``line``.
  """
  syllable_marks = Separator(phone=None, word=' ', syllable="|")
  return phonemize(
      line,
      language='en-us',
      backend='festival',
      with_stress=False,
      separator=syllable_marks,
      strip=True,
  )

# Smoke test: phonemize a one-element list of text and print the result.
text = ["pet pug arthur"]
print(get_festival_phonemes(text))

def count_syllables(doc):
  """Count the syllables in each phrase of a phonemized text.

  Args:
    doc: Phonemized text in which phrases are separated by " / ", words by
      whitespace and syllables by "|".

  Returns:
    A list with one syllable count per phrase, in phrase order. An empty
    phrase contributes 0.
  """
  counts = []
  for phrase in doc.split(" / "):
    count = 0
    # split() with no argument drops the empty tokens that repeated, leading
    # or trailing whitespace would otherwise produce; each of those used to be
    # counted as one syllable.
    for word in phrase.split():
      # Ignore empty pieces left by a leading, trailing or doubled "|".
      count += sum(1 for syllable in word.split("|") if syllable)
    counts.append(count)
  return counts

# Sanity check for count_syllables: this sample prints [5, 7, 5].
print(count_syllables("waa|shihng dhax dih|shaxz / thihng|kaxng ax|bawt ax hhay|kuw / werdz flow layk wao|ter"))

# Sanity check for the GRUEN scorer: pass a one-element list holding a haiku
# written as three sentences and print what get_gruen returns.
doc = ["Autumn is nearing. Golden trees tempting my eyes. While songs tease my ears"]
gruen_results = gruen.get_gruen(doc)
print("gruen: ", gruen_results)

# Sanity check for Detoxify: predict on the same haiku in " / "-separated form
# and print every key of the result with its value rounded to 5 places.
from detoxify import Detoxify
results = Detoxify('original').predict("Autumn is nearing / Golden trees tempting my eyes / While songs tease my ears")
for key, value in results.items():
  print(key, ' : ', round(value, 5))

class FrozenBNBLinear(nn.Module):
    """Linear layer whose weight is stored as frozen, blockwise-quantized buffers.

    ``weight``, ``absmax`` and ``code`` are registered as buffers with gradients
    disabled. ``bias`` (an ``nn.Parameter`` or ``None``) is kept as given, and
    ``adapter`` starts as ``None`` and may be assigned a module afterwards.
    """

    def __init__(self, weight, absmax, code, bias=None):
        """Store the quantized state; ``weight.shape`` is (out_features, in_features)."""
        assert bias is None or isinstance(bias, nn.Parameter)
        super().__init__()
        self.out_features, self.in_features = weight.shape
        # Registration order is weight, absmax, code.
        for buffer_name, tensor in (("weight", weight), ("absmax", absmax), ("code", code)):
            self.register_buffer(buffer_name, tensor.requires_grad_(False))
        self.adapter = None
        self.bias = bias

    def forward(self, input):
        """Apply the dequantized linear map, then add the adapter output if one is set."""
        result = DequantizeAndLinear.apply(
            input, self.weight, self.absmax, self.code, self.bias
        )
        # Truthiness test on the adapter (not an explicit ``is not None``).
        if self.adapter:
            result += self.adapter(input)
        return result

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
        """Build a frozen layer by quantizing the weight of an existing ``nn.Linear``."""
        quantized, (absmax, code) = quantize_blockise_lowmemory(linear.weight)
        return cls(quantized, absmax, code, linear.bias)

    def __repr__(self):
        """Return ``ClassName(in_features, out_features)``."""
        return f"{type(self).__name__}({self.in_features}, {self.out_features})"
 
 
class DequantizeAndLinear(torch.autograd.Function): 
    """Autograd function that dequantizes blockwise-quantized weights and applies ``F.linear``.

    ``backward`` returns a gradient for ``input`` and, when a bias was given,
    for ``bias``; the quantized weights, ``absmax`` and ``code`` get ``None``.
    """
    @staticmethod
    @custom_fwd
    def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
                absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
        """Dequantize the weights and return ``F.linear(input, weights, bias)``.

        Only the quantized weights and their quantization state are saved for
        backward, not the dequantized matrix; ``bias`` may be ``None``.
        """
        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        # NOTE(review): `input` is saved here but backward never reads it.
        ctx.save_for_backward(input, weights_quantized, absmax, code)
        # Remember whether a bias was passed so backward knows whether to return its gradient.
        ctx._has_bias = bias is not None
        return F.linear(input, weights_deq, bias)
 
    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output: torch.Tensor):
        """Return gradients for the five forward inputs, in forward-argument order."""
        # The quantized weights, absmax and code must not require gradients.
        assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
        input, weights_quantized, absmax, code = ctx.saved_tensors
        # grad_output: [*batch, out_features]
        # Dequantize again here because forward saved only the quantized form.
        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        grad_input = grad_output @ weights_deq
        # Collapse all leading dimensions into one and sum over it for the bias gradient.
        grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
        return grad_input, None, None, None, grad_bias
 
 
class FrozenBNBEmbedding(nn.Module):
    """Embedding table stored as frozen, blockwise-quantized buffers.

    ``weight``, ``absmax`` and ``code`` are registered as buffers with gradients
    disabled; ``adapter`` starts as ``None`` and may be assigned a module afterwards.
    """

    def __init__(self, weight, absmax, code):
        """Store the quantized state; ``weight.shape`` is (num_embeddings, embedding_dim)."""
        super().__init__()
        self.num_embeddings, self.embedding_dim = weight.shape
        # Registration order is weight, absmax, code.
        for buffer_name, tensor in (("weight", weight), ("absmax", absmax), ("code", code)):
            self.register_buffer(buffer_name, tensor.requires_grad_(False))
        self.adapter = None

    def forward(self, input, **kwargs):
        """Look up ``input`` in the dequantized table; ``kwargs`` go to ``F.embedding``."""
        with torch.no_grad():
            # Neither the quantized weights nor the input indices are differentiable,
            # so the lookup itself runs without gradient tracking.
            table = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
            embedded = F.embedding(input, table, **kwargs)
        # Truthiness test on the adapter (not an explicit ``is not None``).
        if self.adapter:
            embedded += self.adapter(input)
        return embedded

    @classmethod
    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
        """Build a frozen embedding by quantizing the weight of an existing ``nn.Embedding``."""
        quantized, (absmax, code) = quantize_blockise_lowmemory(embedding.weight)
        return cls(quantized, absmax, code)

    def __repr__(self):
        """Return ``ClassName(num_embeddings, embedding_dim)``."""
        return f"{type(self).__name__}({self.num_embeddings}, {self.embedding_dim})"
 
def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
    """Blockwise-quantize ``matrix`` one chunk at a time.

    The flattened tensor is cut into slices of ``chunk_size`` elements; each
    slice is cloned and passed to ``quantize_blockwise``. The ``code`` returned
    by one call is fed into the next, so the first call starts with ``None``.

    Args:
        matrix: Tensor to quantize; it is flattened with ``view(-1)``.
        chunk_size: Elements per chunk; must be a multiple of 4096.

    Returns:
        ``(quantized, (absmax, code))`` where ``quantized`` has the shape of
        ``matrix`` and ``absmax`` is the concatenation of the per-chunk values.
    """
    assert chunk_size % 4096 == 0
    code = None
    quantized_parts, absmax_parts = [], []
    flat = matrix.view(-1)
    chunk_count = (matrix.numel() - 1) // chunk_size + 1
    for index in range(chunk_count):
        start = index * chunk_size
        piece = flat[start:start + chunk_size].clone()
        quantized_piece, (absmax_piece, code) = quantize_blockwise(piece, code=code)
        quantized_parts.append(quantized_piece)
        absmax_parts.append(absmax_piece)

    quantized = torch.cat(quantized_parts).reshape_as(matrix)
    return quantized, (torch.cat(absmax_parts), code)
 
 
def convert_to_int8(model):
    """Replace every ``nn.Linear`` and ``nn.Embedding`` child with a frozen 8-bit module.

    Walks a snapshot of ``model.modules()`` and swaps matching children in
    place. The replacements are built from all-zero placeholder tensors of the
    right shapes (uint8 weight, one ``absmax`` entry per 4096 weight elements,
    a 256-entry ``code``); a linear layer's bias object is carried over as is.
    Each replaced linear layer is printed. Returns ``None``.
    """
    for parent in list(model.modules()):
        for attr_name, layer in parent.named_children():
            if isinstance(layer, nn.Linear):
                print(attr_name, layer)
                block_count = (layer.weight.numel() - 1) // 4096 + 1
                frozen = FrozenBNBLinear(
                    weight=torch.zeros(layer.out_features, layer.in_features, dtype=torch.uint8),
                    absmax=torch.zeros(block_count),
                    code=torch.zeros(256),
                    bias=layer.bias,
                )
            elif isinstance(layer, nn.Embedding):
                block_count = (layer.weight.numel() - 1) // 4096 + 1
                frozen = FrozenBNBEmbedding(
                    weight=torch.zeros(layer.num_embeddings, layer.embedding_dim, dtype=torch.uint8),
                    absmax=torch.zeros(block_count),
                    code=torch.zeros(256),
                )
            else:
                # Any other kind of child is left untouched.
                continue
            setattr(parent, attr_name, frozen)

class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
    """GPT-J block whose attention and MLP sub-modules are converted by ``convert_to_int8``."""
    def __init__(self, config):
        super().__init__(config)

        # Swap the linear/embedding layers inside both sub-modules for frozen 8-bit ones.
        convert_to_int8(self.attn)
        convert_to_int8(self.mlp)


class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
    """GPT-J model that runs ``convert_to_int8`` on itself after the parent constructor.

    NOTE(review): the visible monkey-patch line installs only ``GPTJBlock`` into
    ``transformers``; this class is defined here but not patched in.
    """
    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)
        

class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
    """GPT-J causal-LM wrapper that runs ``convert_to_int8`` on itself after the parent constructor.

    NOTE(review): the visible monkey-patch line installs only ``GPTJBlock`` into
    ``transformers``; this class is defined here but not patched in.
    """
    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)


# Replace the library's GPTJBlock with the subclass defined above, so anything that
# looks the class up on transformers.models.gptj.modeling_gptj gets the 8-bit variant.
transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock  # monkey-patch GPT-J

import transformers
# NOTE(review): `config` is not referenced again in the visible code.
config = transformers.GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

# Load the saved model object onto the GPU and put it in eval mode.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source.
gpt = torch.load("/content/gpt-j-8bit_deep_haikul.pt",  map_location=torch.device('cuda'))
gpt.eval()

In [None]:
#@title Generate Haikus

import warnings
warnings.filterwarnings("ignore")

topic = 'autumn' #@param {type:"string"}
# The model is prompted with "(<topic> =" and is expected to continue with
# "<line> / <line> / <line>)".
prompt = "(" + topic + " ="
with torch.no_grad():
  prompt_tokens = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
  # Sample 20 candidate continuations in one call.
  sample_outputs = gpt.generate(prompt_tokens, max_length=40, do_sample=True,
    num_return_sequences=20, temperature=0.8)

haikus = []
haikus_no_slashes = []

for sample_output in sample_outputs:
  doc = tokenizer.decode(sample_output, skip_special_tokens=True)
  # Keep the text before the first ")" and drop the leading "(".
  haiku_with_prompt = doc.split(")")[0][1:].strip()
  pieces = haiku_with_prompt.split(" = ")
  # A sample that closes the parenthesis right away (or otherwise lacks the
  # " = " separator) has no haiku text; skip it instead of raising IndexError.
  if len(pieces) < 2:
    continue
  haiku = pieces[1].strip()
  if not haiku:
    continue
  haikus.append(haiku)
  # GRUEN is scored on the haiku as plain running text, without line separators.
  haikus_no_slashes.append(haiku.replace(" / ", " "))

# Score all candidates in one call; skip the scorer when nothing was parsed.
scores = gruen.get_gruen(haikus_no_slashes) if haikus_no_slashes else []

print()
print("Deep Haiku generation for", topic +":")
print()

# Build the toxicity classifier once rather than once per candidate.
toxicity_model = Detoxify('original')

for h, s in zip(haikus, scores):
  # Phonemize each line separately, then rejoin so count_syllables sees one
  # " / "-separated string.
  phones = " / ".join(get_festival_phonemes(part) for part in h.split(" / "))
  if count_syllables(phones) != [5, 7, 5]:
    continue
  # Keep haikus with a GRUEN score above 0.5 and toxicity below 0.9; the
  # classifier only runs when the GRUEN check has already passed.
  if s > 0.5 and toxicity_model.predict(h)["toxicity"] < 0.9:
    # Drop sentence-final periods at line ends for display.
    h = h.replace(". / ", " / ")
    if h.endswith("."):
      h = h[:-1]
    print(h)