In [1]:
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, DistilBertConfig, BertForSequenceClassification, BertTokenizer, BertConfig
from datasets import load_dataset,load_metric
import numpy as np


from accelerate import Accelerator


accelerator = Accelerator()
device = accelerator.device




  from .autonotebook import tqdm as notebook_tqdm
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


#  Prediction model example

In [2]:
texts = "I [MASK] this movie!"

# pretrained_name = "distilbert-base-uncased-finetuned-sst-2-english"
pretrained_name = "textattack/bert-base-uncased-SST-2"

# pred_config = DistilBertConfig.from_pretrained(pretrained_name)
# pred_tokenizer = DistilBertTokenizer.from_pretrained(pretrained_name)
# pred_model = DistilBertForSequenceClassification.from_pretrained(pretrained_name).to(device)

pred_config = BertConfig.from_pretrained(pretrained_name)
pred_tokenizer = BertTokenizer.from_pretrained(pretrained_name)
pred_model = BertForSequenceClassification.from_pretrained(pretrained_name).to(device)

inputs = pred_tokenizer(texts, return_tensors="pt")
with torch.no_grad():
    inputs = {key:val.to(device) for key,val in inputs.items()}
    logits = pred_model(**inputs).logits

predicted_class_id = logits.argmax().item()
print(pred_model.config.id2label[predicted_class_id])

print(inputs)

LABEL_1
{'input_ids': tensor([[ 101, 1045,  103, 2023, 3185,  999,  102]], device='cuda:0'), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}


# Explaination model example

In [3]:
# from transformers import DistilBertModel

# explain_pretrained_name = "distilbert-base-uncased-finetuned-sst-2-english"

# explain_tokenizer = DistilBertTokenizer.from_pretrained(explain_pretrained_name)
# explain_model = DistilBertModel.from_pretrained(explain_pretrained_name).to(device)
# explain_config = explain_model.config

# explain_inputs = explain_tokenizer(texts, return_tensors="pt")

# # print(explain_inputs)

# # with torch.no_grad():
# #     explain_inputs = {key:val.to(device) for key,val in explain_inputs.items()}
# #     explain_logits = explain_model(**explain_inputs).logits
# # print(explain_tokenizer)

In [4]:
# pred_hidden_dim = pred_model.config.dim
# num_labels = pred_model.config.num_labels
# explain_hidden_dim = explain_config.dim
# # clip_model.config
# print(pred_hidden_dim, explain_hidden_dim, num_labels)

In [5]:
from maskgen.text_models.text_maskgen_model import MaskGeneratingModel
    
# pred_model.config

In [6]:
# from maskgen.models.mask_generating_model12 import MaskGeneratingModel

pred_hidden_dim = pred_model.config.hidden_size
num_labels = pred_model.config.num_labels

mask_gen_model = MaskGeneratingModel(pred_model, hidden_size=pred_hidden_dim, num_classes=num_labels)
mask_gen_model.to(device)
print()




# Load dataset

In [7]:
from datasets import load_dataset
imdb = load_dataset("imdb")
idx = 0
texts = imdb["test"][idx]['text']
print(texts)

inputs = pred_tokenizer(texts, return_tensors="pt")
with torch.no_grad():
    inputs = {key:val.to(device) for key,val in inputs.items()}
    logits = pred_model(**inputs).logits

predicted_class_id = logits.argmax().item()
pred_label = pred_model.config.id2label[predicted_class_id]
print("pred label", pred_label)
print("True label", pred_model.config.id2label[imdb["test"][idx]['label']])

I love sci-fi and am willing to put up with a lot. Sci-fi movies/TV are usually underfunded, under-appreciated and misunderstood. I tried to like this, I really did, but it is to good TV sci-fi as Babylon 5 is to Star Trek (the original). Silly prosthetics, cheap cardboard sets, stilted dialogues, CG that doesn't match the background, and painfully one-dimensional characters cannot be overcome with a 'sci-fi' setting. (I'm sure there are those of you out there who think Babylon 5 is good sci-fi TV. It's not. It's clichéd and uninspiring.) While US viewers might like emotion and character development, sci-fi is a genre that does not take itself seriously (cf. Star Trek). It may treat important issues, yet not as a serious philosophy. It's really difficult to care about the characters here as they are not simply foolish, just missing a spark of life. Their actions and reactions are wooden and predictable, often painful to watch. The makers of Earth KNOW it's rubbish as they have to alway

In [8]:
from captum.attr import visualization
expl = mask_gen_model.attribute_text(inputs['input_ids'], inputs['attention_mask'])[0]
tokens = pred_tokenizer.convert_ids_to_tokens(inputs['input_ids'].flatten()[1:-1])
vis_data_records = [visualization.VisualizationDataRecord(
                                expl,
                                0,
                                0,
                                0,
                                0,
                                1,       
                                tokens,
                                1)]
                            
visualization.visualize_text(vis_data_records)
# 

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),0.0,1.0,"i love sci - fi and am willing to put up with a lot . sci - fi movies / tv are usually under ##fu ##nded , under - appreciated and misunderstood . i tried to like this , i really did , but it is to good tv sci - fi as babylon 5 is to star trek ( the original ) . silly pro ##st ##hetic ##s , cheap cardboard sets , stil ##ted dialogues , c ##g that doesn ' t match the background , and painfully one - dimensional characters cannot be overcome with a ' sci - fi ' setting . ( i ' m sure there are those of you out there who think babylon 5 is good sci - fi tv . it ' s not . it ' s cl ##iche ##d and un ##ins ##pi ##ring . ) while us viewers might like emotion and character development , sci - fi is a genre that does not take itself seriously ( cf . star trek ) . it may treat important issues , yet not as a serious philosophy . it ' s really difficult to care about the characters here as they are not simply foolish , just missing a spark of life . their actions and reactions are wooden and predictable , often painful to watch . the makers of earth know it ' s rubbish as they have to always say "" gene rod ##den ##berry ' s earth . . . "" otherwise people would not continue watching . rod ##den ##berry ' s ashes must be turning in their orbit as this dull , cheap , poorly edited ( watching it without ad ##vert breaks really brings this home ) tr ##ud ##ging tr ##aba ##nt of a show lumber ##s into space . spoil ##er . so , kill off a main character . and then bring him back as another actor . je ##ee ##z ! dallas all over again ."
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),0.0,1.0,"i love sci - fi and am willing to put up with a lot . sci - fi movies / tv are usually under ##fu ##nded , under - appreciated and misunderstood . i tried to like this , i really did , but it is to good tv sci - fi as babylon 5 is to star trek ( the original ) . silly pro ##st ##hetic ##s , cheap cardboard sets , stil ##ted dialogues , c ##g that doesn ' t match the background , and painfully one - dimensional characters cannot be overcome with a ' sci - fi ' setting . ( i ' m sure there are those of you out there who think babylon 5 is good sci - fi tv . it ' s not . it ' s cl ##iche ##d and un ##ins ##pi ##ring . ) while us viewers might like emotion and character development , sci - fi is a genre that does not take itself seriously ( cf . star trek ) . it may treat important issues , yet not as a serious philosophy . it ' s really difficult to care about the characters here as they are not simply foolish , just missing a spark of life . their actions and reactions are wooden and predictable , often painful to watch . the makers of earth know it ' s rubbish as they have to always say "" gene rod ##den ##berry ' s earth . . . "" otherwise people would not continue watching . rod ##den ##berry ' s ashes must be turning in their orbit as this dull , cheap , poorly edited ( watching it without ad ##vert breaks really brings this home ) tr ##ud ##ging tr ##aba ##nt of a show lumber ##s into space . spoil ##er . so , kill off a main character . and then bring him back as another actor . je ##ee ##z ! dallas all over again ."
,,,,


In [9]:
expl.shape

torch.Size([339])

# training

In [10]:
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding
from datasets import load_dataset

imdb = load_dataset("imdb")
train_ds = imdb['train']

# def preprocess_function(examples):
#     return pred_tokenizer(examples["text"], truncation=True, padding="max_length")

# tokenized_imdb = train_ds.map(preprocess_function, batched=True)

def collate_fn(examples):
    texts = [example['text'] for example in examples]
    labels = [example['label'] for example in examples]
    
    # Tokenize texts
    batch = pred_tokenizer(texts, return_tensors="pt", truncation=True, padding=True)
    
    # Add labels
    batch['labels'] = torch.tensor(labels, dtype=torch.long)
    return batch
# train_ds.set_transform(preprocess)


In [11]:

# train_ds[:2]

# training params

In [12]:
batch_size = 256
num_workers = 8
train_dataloader = DataLoader(train_ds, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=num_workers)
n_steps = 2
n_samples = 5

params_to_optimize = [name for name, param in mask_gen_model.named_parameters() if param.requires_grad]
print("params to be optimized: ")
print(params_to_optimize)

params to be optimized: 
['similarity_measure.logit_scale', 'similarity_measure.pred_map.input_layer.weight', 'similarity_measure.pred_map.input_layer.bias', 'similarity_measure.pred_map.layers.0.0.weight', 'similarity_measure.pred_map.layers.0.0.bias', 'similarity_measure.pred_map.layers.0.3.weight', 'similarity_measure.pred_map.layers.0.3.bias', 'similarity_measure.pred_map.layers.0.6.weight', 'similarity_measure.pred_map.layers.0.6.bias', 'similarity_measure.pred_map.layers.1.0.weight', 'similarity_measure.pred_map.layers.1.0.bias', 'similarity_measure.pred_map.layers.1.3.weight', 'similarity_measure.pred_map.layers.1.3.bias', 'similarity_measure.pred_map.layers.1.6.weight', 'similarity_measure.pred_map.layers.1.6.bias', 'similarity_measure.pred_map.output_layer.weight', 'similarity_measure.pred_map.output_layer.bias', 'similarity_measure.explain_map.input_layer.weight', 'similarity_measure.explain_map.input_layer.bias', 'similarity_measure.explain_map.layers.0.0.weight', 'similarit

In [13]:
for idx, data in enumerate(train_dataloader):
    input_ids = data['input_ids'].to(device)
    attention_mask = data['attention_mask'].to(device)
    print(input_ids)
    print(attention_mask)
    if idx >= 0:
        break

tensor([[  101,  2034,  1997,  ...,     0,     0,     0],
        [  101,  1996,  2273,  ...,  7432,  5307,   102],
        [  101,  1045,  2387,  ...,     0,     0,     0],
        ...,
        [  101,  1045,  2031,  ...,  6047,  1998,   102],
        [  101, 12656,  5469,  ...,     0,     0,     0],
        [  101,  1045,  2387,  ...,     0,     0,     0]], device='cuda:0')
tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0')


In [14]:
from tqdm import tqdm

params_to_optimize = [param for param in mask_gen_model.parameters() if param.requires_grad]
# optimizer = torch.optim.Adam(params_to_optimize, lr=1e-3, weight_decay=1e-5)
optimizer = torch.optim.Adam(params_to_optimize, lr=1e-3, weight_decay=1e-5)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.9)


print()

for epoch in range(10):
    pbar = tqdm(train_dataloader)
    for idx, data in enumerate(pbar):
        input_ids = data['input_ids'].to(device)
        attention_mask = data['attention_mask'].to(device)
        loss_dict = mask_gen_model.train_one_batch(input_ids, attention_mask, optimizer=optimizer, n_steps=n_steps, n_samples=n_samples)
        # scheduler.step()
        pbar.set_description(f"Epoch {epoch+1}, Step {idx+1}: Loss = {loss_dict['loss'].item():.4f}, " 
                             f"Reward Loss = {loss_dict['reward_loss'].item():.4f}, "
                            #  f"Regret Loss = {loss_dict['regret_loss'].item():.4f}, "
                             f"Mask Loss = {loss_dict['mask_loss'].item():.4f} "
                            #  f"alt_mask_loss = {loss_dict['alt_mask_loss'].item():.4f} "
                             f"mask_mean = {loss_dict['mask_mean'].item():.4f} "
                             f"prob_mean = {loss_dict['prob_mean'].item():.4f} "
                             )
        if idx % 10 == 0:
            print()
        if (idx) % 10 == 0:
            
            torch.save(mask_gen_model.state_dict(), f'text_mask_gen_model/mask_gen_model_{epoch}_{idx}.pth') 



torch.save(mask_gen_model.state_dict(), f'text_mask_gen_model/mask_gen_model_final_{epoch}_{idx}.pth') 





Epoch 1, Step 1: Loss = 0.3757, Reward Loss = -0.0255, Mask Loss = 0.4012 mask_mean = 0.4067 prob_mean = 0.7590 :   0%|          | 0/98 [01:01<?, ?it/s]




Epoch 1, Step 9: Loss = -0.1438, Reward Loss = -0.4804, Mask Loss = 0.3366 mask_mean = 0.3464 prob_mean = 0.6034 :   9%|▉         | 9/98 [08:09<1:17:58, 52.56s/it]

In [None]:
3262

3262