# Imports & Setup

In [4]:
from transformers import BartTokenizer, BartForConditionalGeneration
import torch
import random


# Load BART Model & Tokenizer

In [5]:
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

model.eval()



Loading weights:   0%|          | 0/259 [00:00<?, ?it/s]

BartForConditionalGeneration(
  (model): BartModel(
    (shared): BartScaledWordEmbedding(50265, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): BartScaledWordEmbedding(50265, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_n

In [55]:
text = "This message is similar to messages that were identified as spam in the past."
print("Original Text:")
print(text)


Original Text:
This message is similar to messages that were identified as spam in the past.


# Tokenize Text

In [56]:
def tokenize(text):
    return tokenizer(text,return_tensors="pt")

# Noise Function

## Token Masking

In [10]:
def token_m(text,mask_prob=0.3):
    words = text.split()
    masked=[]

    for w in words:
        if random.random() < mask_prob:
            masked.append(tokenizer.mask_token)
        else:
            masked.append(w)

    return " ".join(masked)

In [58]:
print(token_m(text))


<mask> message <mask> similar to messages <mask> were identified as spam in the past.


## Token Deletion

In [17]:
def token_deletion(text, delete_prob=0.3):
    words = text.split()
    kept = []
    
    for w in words:
        if random.random() > delete_prob:
            kept.append(w)
    return " ".join(kept)

In [60]:
print(token_deletion(text))


message is to messages that identified spam the past.


## Sentence Permutation

In [25]:
def sentence_permutation(text):
    sentences = text.split(". ")
    random.shuffle(sentences)
    return ". ".join(sentences)


In [27]:
multi_sentence = "BART is powerful. It uses noise. It reconstructs text."
print(sentence_permutation(multi_sentence))


It uses noise. It reconstructs text.. BART is powerful


## Span Infilling

In [33]:
def span_fil(text,span_len=2):
    words = text.split()

    if len(words) <= span_len:
        return text

    start = random.randint(0,len(words)-span_len)
    corrupted = (
        words[:start] + [tokenizer.mask_token] + words[start + span_len:]

    )

    return " ".join(corrupted) 

In [62]:
print(span_fil(text, span_len=3))


This message is similar to <mask> identified as spam in the past.


# Reconstruction

In [66]:
noisy_text = span_fil(text, span_len=3)
print("original text:", text)
print("Noisy:", noisy_text)

original text: This message is similar to messages that were identified as spam in the past.
Noisy: This message is similar to messages that <mask> spam in the past.


In [69]:
inputs = tokenizer(noisy_text,return_tensors="pt")
output = model.generate(**inputs,max_length=50)

print("Reconstructed:", tokenizer.decode(output[0], skip_special_tokens=True))

Reconstructed: This message is similar to messages that have been sent as spam in the past.
