#### Dependencies

In [None]:
# !pip install torch pytorch_lightning datasets wandb transformers

In [None]:
import os
import torch
import tqdm
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertConfig, BertTokenizer, BertTokenizerFast
from datasets import load_dataset
import pytorch_lightning as pl
import wandb

In [None]:
# Display the installed PyTorch version as the cell output.
torch.__version__

### Custom Data

In [None]:
# custom dataset class 
class SentimentDataset(Dataset):
    """Map-style dataset that tokenizes one text/label pair per item.

    Args:
        tokenizer: Object exposing ``encode_plus`` (e.g. a Hugging Face BERT
            tokenizer).
        text: Indexable collection of raw strings.
        target: Indexable collection of integer class labels, aligned with
            ``text``.
        max_len: Length every sequence is padded or truncated to.
    """

    def __init__(self, tokenizer, text, target, max_len=512):
        self.tokenizer = tokenizer
        self.text = text
        self.target = target
        self.max_len = max_len

    def __len__(self):
        """Return the number of samples (one per entry of ``text``)."""
        return len(self.text)

    def __getitem__(self, idx):
        """Tokenize sample ``idx``.

        Returns:
            dict with the raw ``text``, the flattened ``input_ids`` and
            ``attention_mask`` tensors, and ``targets`` as a ``torch.long``
            tensor.
        """
        text = self.text[idx]
        target = self.target[idx]

        # Encode the text; `padding='max_length'` replaces the deprecated
        # `pad_to_max_length=True` and pads every sample to `max_len`.
        encoding = self.tokenizer.encode_plus(
            text=text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
            truncation=True,
        )

        # `return_tensors='pt'` adds a leading batch axis; flatten it away so
        # the DataLoader can stack samples itself.
        return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'targets': torch.tensor(target, dtype=torch.long),
        }
        

### BERTModel PyTorch

In [None]:
class BertClassifier(torch.nn.Module):
    """Encoder followed by a two-layer MLP classification head.

    Args:
        config: Object providing ``hidden_size`` and ``hidden_dropout_prob``
            (e.g. a ``BertConfig``).
        model: The encoder, called as ``model(input_ids, attention_mask)``.
            Its output must be indexable by position: element 0 is returned
            as the top-layer output, element 1 is fed to the head, and
            element 2 (if present) is returned as the layer-wise outputs.
        dim: Width of the hidden layer of the head.
        num_classes: Number of output logits.
    """

    def __init__(self, config, model, dim=256, num_classes=2):
        super().__init__()

        self.config = config
        self.base = model

        # Classification head: dropout -> linear -> ReLU -> dropout -> linear.
        # Module positions are kept as-is so existing state dicts
        # (head.1.*, head.4.*) still load.
        self.head = torch.nn.Sequential(
            torch.nn.Dropout(p=self.config.hidden_dropout_prob),
            torch.nn.Linear(in_features=self.config.hidden_size, out_features=dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=self.config.hidden_dropout_prob),
            torch.nn.Linear(in_features=dim, out_features=num_classes),
        )

    def forward(self, input_ids, attention_mask=None):
        """Run the encoder and the head.

        Returns:
            ``(top_layer, logits, layers)``; ``layers`` is ``None`` when the
            encoder output has fewer than three elements.
        """
        encoded = self.base(input_ids, attention_mask)

        # Index by position instead of tuple-unpacking: this works for plain
        # tuples and for dict-like model outputs (whose iteration yields keys
        # rather than tensors), and tolerates a missing third element.
        top_layer = encoded[0]
        pooled = encoded[1]
        layers = encoded[2] if len(encoded) > 2 else None

        outputs = self.head(pooled)
        return top_layer, outputs, layers
        

### Lightning Model

In [None]:
class BertFinetuner(pl.LightningModule):
    """Lightning module that fine-tunes a classifier on a CSV sentiment file.

    The first 20% of ``data_file`` is used for validation and the remaining
    80% for training. Training loss/accuracy are sent to Weights & Biases on
    every step; validation metrics are aggregated and sent once per epoch.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` and
            returning ``(top_layer, logits, layers)``.
        tokenizer: Tokenizer handed to ``SentimentDataset``.
        data_file: Path of the CSV file to load.
        use_cols: ``[text_column, label_column]`` names inside the CSV.
        batch_size: Batch size used by both dataloaders.
    """

    def __init__(self, model=None, tokenizer=None, data_file="./data/twitter/train.csv", use_cols=['review_text', 'sentiment'], batch_size=32):
        super().__init__()

        self.model = model
        self.data_file = data_file
        # Copy so the shared default list (or the caller's list) is never
        # mutated through this instance.
        self.use_cols = list(use_cols)
        self.batch_size = batch_size
        self.tokenizer = tokenizer

        # `Fbeta` is looked up when the module is instantiated, so it must be
        # imported before the first instance is created.
        self.f_score = Fbeta()

    def _build_loader(self, split, shuffle):
        """Load ``split`` of the CSV and wrap it in a DataLoader."""
        data = load_dataset("csv", data_files=self.data_file, split=split)
        text_col, target_col = self.use_cols[0], self.use_cols[1]
        dataset = SentimentDataset(
            tokenizer=self.tokenizer,
            text=data[text_col],
            target=data[target_col],
        )
        return DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=shuffle)

    def accuracy(self, outputs, targets):
        """Return the fraction of positions where ``outputs == targets``.

        Both arguments are expected to be 1-D tensors of equal length (the
        training/validation steps pass argmax class ids and labels). Returns
        a Python float; an empty batch yields ``0.0``.
        """
        total = outputs.shape[0]
        if total == 0:
            return 0.0
        return (outputs == targets).sum().item() / total

    def forward(self, input_ids, attention_mask=None):
        """Delegate to the wrapped model and return its three outputs."""
        top_layer, outputs, layers = self.model(input_ids, attention_mask)
        return top_layer, outputs, layers

    def configure_optimizers(self):
        """Adam over all parameters with a learning rate of 1e-5."""
        return torch.optim.Adam(params=self.parameters(), lr=1e-5)

    def train_dataloader(self):
        """Shuffled loader over the last 80% of the data."""
        # The first 20% of the data is reserved for validation.
        return self._build_loader(split='train[20%:]', shuffle=True)

    def training_step(self, batch, batch_idx):
        """Compute cross-entropy loss and accuracy for one training batch."""
        input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['targets']
        _, logits, _ = self(input_ids, attention_mask)
        loss = F.cross_entropy(logits, targets)
        acc = self.accuracy(logits.argmax(dim=1), targets)
        # Log plain numbers so the logger never holds on to graph tensors.
        wandb.log({"Loss": loss.item(), "Accuracy": acc})
        return {"loss": loss, "accuracy": torch.tensor(acc)}

    def val_dataloader(self):
        """Loader over the first 20% of the data, in file order."""
        # Evaluation does not benefit from shuffling; a fixed order keeps the
        # per-batch metrics reproducible between epochs.
        return self._build_loader(split='train[:20%]', shuffle=False)

    def validation_step(self, batch, batch_idx):
        """Compute loss and accuracy for one validation batch and update F-beta."""
        input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['targets']
        _, logits, _ = self(input_ids, attention_mask)
        loss = F.cross_entropy(logits, targets)
        predictions = logits.argmax(dim=1)
        acc = self.accuracy(predictions, targets)
        self.f_score(predictions, targets)
        return {"val_loss": loss, "val_accuracy": torch.tensor(acc)}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation metrics and log them to wandb."""
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['val_accuracy'] for x in outputs]).mean()
        avg_f_score = self.f_score.compute()
        metrics = {"val_loss": avg_loss, "val_accuracy": avg_acc, "val_fb": avg_f_score}
        wandb.log(metrics)
        # Return a separate dict that includes `val_loss`; the previous version
        # repeated the 'val_accuracy' key, so the loss was silently dropped.
        return dict(metrics)
    

### Training 

In [None]:
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.metrics import Fbeta 
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBar

In [None]:
# Data location: the CSV passed to the finetuner is <ROOT_DIR>/<DATASET>.csv.
DATASET = "books"
ROOT_DIR = "../input/amazonproductsreview/amazon-review/"

# Training hyperparameters.
EPOCH = 20        # upper bound on training epochs (Trainer max_epochs)
BATCH_SIZE = 16   # batch size for both dataloaders
NUM_CLASSES = 2   # number of logits produced by the classifier head

# NOTE: not referenced elsewhere in the visible notebook; SentimentDataset
# falls back to its own default max_len of 512.
MAX_LEN = 512

In [None]:
# Weights & Biases logger: one run per dataset inside the
# "domain-adaptation" project, with model logging switched on.
logger = WandbLogger(
    project="domain-adaptation",
    name=DATASET,
    save_dir="../working/",
    log_model=True,
)

In [None]:
# Callbacks. The monitored metric is an accuracy, so higher is better: state
# mode="max" explicitly instead of relying on the library's default/inferred
# mode, which could otherwise keep the worst checkpoint or stop on improvement.
early_stopping = EarlyStopping(
    monitor="val_accuracy",
    mode="max",
)
# Keep only the single best checkpoint according to validation accuracy.
model_checkpoint = ModelCheckpoint(
    filepath="{epoch}-{val_accuracy:.2f}-{val_loss:.2f}",
    monitor="val_accuracy",
    mode="max",
    save_top_k=1,
)
progress_bar = ProgressBar()

In [None]:
# Pretrained checkpoint shared by the config, the tokenizer and the encoder.
model_name = "bert-base-uncased"

# The config is created with output_hidden_states=True and is passed to both
# the encoder and the classifier below.
config = BertConfig.from_pretrained(
    model_name,
    output_hidden_states=True,
)
tokenizer = BertTokenizerFast.from_pretrained(
    model_name,
    do_lower_case=True,
)
bert = BertModel.from_pretrained(
    model_name,
    config=config,
)

# Wrap the encoder with the classification head.
classifier = BertClassifier(
    model=bert,
    config=config,
    num_classes=NUM_CLASSES,
)

In [None]:
# Lightning module around the classifier; it reads <ROOT_DIR>/<DATASET>.csv.
model = BertFinetuner(
    data_file=os.path.join(ROOT_DIR, f"{DATASET}.csv"),
    batch_size=BATCH_SIZE,
    tokenizer=tokenizer,
    model=classifier,
)

In [None]:
# Trainer on GPU 0 with 16-bit precision, capped at EPOCH epochs.
tuner = pl.Trainer(
    logger=logger,
    gpus=[0],
    checkpoint_callback=model_checkpoint,
    # Register the early-stopping and progress-bar callbacks created earlier;
    # they were constructed but never handed to the Trainer, so early stopping
    # could not take effect.
    callbacks=[early_stopping, progress_bar],
    max_epochs=EPOCH,
    precision=16
)

In [None]:
# Start training; no dataloaders are passed here, the module supplies them
# through its own train_dataloader / val_dataloader methods.
tuner.fit(model)

#### Save trained state dictionary
- See section 4 : https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_models_for_inference.html

#### 1. Books

In [None]:
# Checkpoint file is named after the dataset, with a "-512" suffix.
PATH = f"{DATASET}-512.pt"
# Persist only the weights (state dict) of the classifier.
torch.save(classifier.state_dict(), PATH)


In [None]:
### Load from state dictionary
# Build the classifier around a fresh, un-shared encoder: passing `bert` here
# would make the loaded model alias the backbone of the live `classifier`.
# The encoder weights come from the state dict, so no pretrained download is
# needed.
classifier_trained = BertClassifier(config=config, model=BertModel(config), num_classes=NUM_CLASSES)
# Switch to inference mode (disables dropout) before using the model.
classifier_trained.eval()
# map_location lets a checkpoint saved from GPU tensors load on any machine.
classifier_trained.load_state_dict(torch.load(PATH, map_location="cpu"))

# You can now evaluate the model on the first 20% of the data (the validation split).

#### 2. DVD

In [None]:
# Checkpoint file is named after the dataset, with a "-512" suffix.
PATH = f"{DATASET}-512.pt"
# Persist only the weights (state dict) of the classifier.
torch.save(classifier.state_dict(), PATH)

In [None]:
### Load from state dictionary
# Build the classifier around a fresh, un-shared encoder: passing `bert` here
# would make the loaded model alias the backbone of the live `classifier`.
# The encoder weights come from the state dict, so no pretrained download is
# needed.
classifier_dvd = BertClassifier(config=config, model=BertModel(config), num_classes=NUM_CLASSES)
# Switch to inference mode (disables dropout) before using the model.
classifier_dvd.eval()
# map_location lets a checkpoint saved from GPU tensors load on any machine.
classifier_dvd.load_state_dict(torch.load(PATH, map_location="cpu"))

# You can now evaluate the model on the first 20% of the data (the validation split).

In [None]:
## There you go 