In [90]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import BertTokenizer, BertModel
from torch.optim import Adam



In [91]:
#loading dataset
dataset= load_dataset("tweet_eval", "irony")

tokenizer = BertTokenizer.from_pretrained ('bert-base-uncased')

class SarcasmDataset (torch.utils.data.Dataset):
    def __init__ (self, split):
        self.data=dataset[split]
        self.texts = self.data[ 'text']
        self.labels= self.data['label']

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        text=self.texts[idx]
        label=self.labels[idx]
        encoding=tokenizer(text, truncation=True, padding ='max_length', max_length=64, return_tensors='pt')

        return {key: val.squeeze(0) for key,val in encoding.items()}, torch.tensor(label)


train_data=SarcasmDataset('train')
val_data= SarcasmDataset('validation')
train_loader = DataLoader(train_data, batch_size=16 , shuffle =True)
val_loader = DataLoader(val_data, batch_size=32)




In [92]:
len(train_data), len(val_data), type(train_data), type(train_loader), len(train_loader)

(2862,
 955,
 __main__.SarcasmDataset,
 torch.utils.data.dataloader.DataLoader,
 179)

In [93]:
class SarcasmClassifier (nn.Module):
    def __init__(self):
        super().__init__()
        self.bert =BertModel.from_pretrained ('bert-base-uncased')
        self.classifier=nn.Linear(self.bert.config.hidden_size, 1)
        self.sigmoid=nn.Sigmoid()
    def forward(self, input_ids, attention_mask):
        outputs= self.bert(input_ids= input_ids, attention_mask=attention_mask)
        pool_output=outputs.pooler_output
        logits=self.classifier(pool_output)
        return self.sigmoid(logits).squeeze(-1)


model= SarcasmClassifier()


device= torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

criterion = nn.BCELoss()
optimizer = Adam(model.parameters(), lr=2e-5)


In [94]:


for epoch in range(3):
    print(f"Current epoch {epoch}")

    model.train()
    total_loss=0
    for batch, labels in train_loader:
        input_ids=batch['input_ids' ].to(device)
        attention_mask= batch['attention_mask'].to(device)
        labels=labels.float().to(device)
        optimizer.zero_grad()
        outputs=model(input_ids, attention_mask)
        loss=criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss+=loss.item()

        print(f"train loss: {total_loss/len(train_loader):.4f}")


    model.eval()
    total_correct = 0
    total =0
    with torch.no_grad():
        for batch, labels in val_loader:
            input_ids= batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels= labels.to(device)
            outputs= model(input_ids, attention_mask)
            preds=(outputs > 0.5).long()
            total_correct+= (preds==labels).sum().item()
            total+=labels.size(0)

    print(f"Validation accuracy: {total_correct/total: .4f}")


Current epoch 0
train loss: 0.0040
train loss: 0.0077
train loss: 0.0116
train loss: 0.0155
train loss: 0.0193
train loss: 0.0230
train loss: 0.0268
train loss: 0.0307
train loss: 0.0347
train loss: 0.0382
train loss: 0.0421
train loss: 0.0462
train loss: 0.0498
train loss: 0.0537
train loss: 0.0576
train loss: 0.0614
train loss: 0.0651
train loss: 0.0689
train loss: 0.0727
train loss: 0.0764
train loss: 0.0801
train loss: 0.0840
train loss: 0.0875
train loss: 0.0914
train loss: 0.0949
train loss: 0.0989
train loss: 0.1026
train loss: 0.1063
train loss: 0.1100
train loss: 0.1135
train loss: 0.1172
train loss: 0.1207
train loss: 0.1245
train loss: 0.1282
train loss: 0.1317
train loss: 0.1355
train loss: 0.1395
train loss: 0.1433
train loss: 0.1468
train loss: 0.1502
train loss: 0.1540
train loss: 0.1574
train loss: 0.1609
train loss: 0.1647
train loss: 0.1679
train loss: 0.1715
train loss: 0.1753
train loss: 0.1791
train loss: 0.1829
train loss: 0.1865
train loss: 0.1901
train loss: 0.1

In [95]:
torch.save(model.state_dict(), "sarcasm_model.pt")


In [96]:
def custom_text(texts: str):
  texts=[texts]
  inputs=tokenizer(texts, padding=True, truncation=True, return_tensors='pt', return_token_type_ids=False)
  inputs={k: v.to(device) for k,v in inputs.items()}

  model.eval()
  model.to(device)
  with torch.no_grad():
    proba=model(**inputs)
  preds=(proba > 0.5).long()

  return [(text,pred.item(), prob.item()) for text, pred, prob in zip(texts,preds, proba)]




In [97]:

def print_result(texts: str):
  results= custom_text(texts)
  for text, pred, prob in results:
    label= "Bazinga" if pred==1 else "Not Funny at all."
    print(f'Text: {text} \n  {label} (probability: {prob: .3f})\n')


In [98]:
print_result("Just what I needed at this hour! a flat tire. Amazing")

Text: Just what I needed at this hour! a flat tire. Amazing 
  Bazinga (probability:  0.829)

