# Exp007: Train grammar rule classifier
This experiment elaborates different ways to train a binary classifier to detect the usage of an EGP rule in a sentence.

In [2]:
import pandas as pd
from dotenv import load_dotenv
load_dotenv()
import os
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader, random_split
import random
from tqdm import tqdm

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd
  from .autonotebook import tqdm as notebook_tqdm


In [91]:
class RuleDetector(torch.nn.Module):
    def __init__(self, bert_encoder, hidden_dim=32, dropout_rate=0.25, train_bert=False):
        super().__init__()
        self.bert = bert_encoder
        for param in self.bert.parameters():
            param.requires_grad = train_bert
        input_dim = self.bert.config.hidden_size*(self.bert.config.num_hidden_layers+1)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.hidden = torch.nn.Linear(input_dim, hidden_dim)
        self.relu = torch.nn.ReLU()
        self.output = torch.nn.Linear(hidden_dim, 1)
        self.sigmoid = torch.nn.Sigmoid()
    
    def forward(self, input_ids, attention_mask, diagnose=False):
        with torch.no_grad():
            outputs = self.bert(input_ids, attention_mask)
            x = torch.cat(outputs.hidden_states, dim=-1)
        if diagnose:
            print(x.shape)
        extended_attention_mask = attention_mask.unsqueeze(-1).expand(x.size())
        x *= extended_attention_mask.float()  # Convert mask to float and apply
        x = self.dropout(x)
        x = self.hidden(x)
        if diagnose:
            print(x.shape)
        x = self.relu(x)
        x = self.output(x)
        x = self.sigmoid(x)
        if diagnose:
            print(x.shape)
        
        max_values, max_indices = torch.max(x, 1)
        return max_values.flatten(), max_indices.flatten()

In [4]:
class SentenceDataset(Dataset):
    def __init__(self, sentences, labels, tokenizer, max_len):
        self.sentences = sentences
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        encoding = self.tokenizer(self.sentences[idx], return_tensors='pt', max_length=self.max_len, padding='max_length')
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }
    
def get_dataset(row, tokenizer, max_len, df, random_negatives=True, ratio = 0.5, max_positive_examples=500):
    # assemble dataset for one construction
    # 50% positive examples
    unique_examples = list(set(row['augmented_examples']))
    sentences = unique_examples[:max_positive_examples]
    labels = [1] * len(sentences)

    num_augs = int(len(sentences) * (1-ratio)) if random_negatives else len(sentences)
    # augmented negative examples
    aug_neg_examples = list(set(row['augmented_negative_examples']).difference(set(row['augmented_examples'])))
    random.shuffle(aug_neg_examples)
    unique_negatives = aug_neg_examples[:num_augs]
    sentences += unique_negatives
    labels += [0] * len(unique_negatives)
    
    if random_negatives:
        num_rands = max_positive_examples - len(unique_negatives) # fill to an even number
        # rest: random negative examples (positive from other constructions)
        neg_examples = [example for sublist in df.loc[df['#'] != row['#'], 'augmented_examples'].to_list() for example in sublist]
        random.shuffle(neg_examples)
        sentences += neg_examples[:num_rands]
        labels += [0] * len(neg_examples[:num_rands])
    assert len(sentences) == 2 * max_positive_examples
    assert sum(labels) == max_positive_examples
    return SentenceDataset(sentences, labels, tokenizer, max_len)

def get_loaders(dataset, batch_size=16):
    total_size = len(dataset)
    train_size = int(0.8 * total_size)
    val_size = total_size - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    return train_dataloader, val_dataloader

In [85]:
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir=os.getenv('CACHE_DIR'))
bert_encoder = BertModel.from_pretrained('bert-base-uncased', cache_dir=os.getenv('CACHE_DIR'), output_hidden_states=True)
classifier = RuleDetector(bert_encoder).to(device)

In [6]:
total_params = sum(p.numel() for p in classifier.parameters() if p.requires_grad)
print(f"Total parameters: {total_params}")

Total parameters: 319553


Test model

In [8]:
input_text = "Switzerland is a beautiful country."
encoded_input = bert_tokenizer(input_text, return_tensors='pt').to(device)
output = classifier(encoded_input['input_ids'], encoded_input['attention_mask'])
output['max_values']

tensor([0.5706], device='cuda:0', grad_fn=<ViewBackward0>)

Train the rule detector for one random rule

In [9]:
egp_examples = pd.read_json("../data/egp_examples.json")

In [79]:
rule = egp_examples.sample(1).iloc[0]
print(rule['Can-do statement'])
print(rule['Example'])
print(random.sample(rule['augmented_examples'], 10))
print(random.sample(rule['augmented_negative_examples'], 10))
classifier = RuleDetector(bert_encoder).to(device)

optimizer = torch.optim.AdamW(classifier.parameters(), 1e-4)
dataset = get_dataset(rule, bert_tokenizer, 64, egp_examples) 
train_dataloader, val_dataloader = get_loaders(dataset)

Can use a limited range of expressions with 'be' + infinitive ('be allowed to', 'be supposed to', 'be able to') with present and past forms of 'be' and with modal 'will'. 
First of all, if you are allowed to go out of the building in your break, you should do it. 

Perhaps you will be allowed to go on holiday with your friends next year. 

The film is supposed to start at 7.00 pm so we'd better meet at 6.30 pm. See you there! 

I was supposed to be meeting my friend Laura but she didn't come. 

I am sorry but I  am not able to meet you next Tuesday. 

We were able to choose the songs ourselves and so I liked them very much.
['I am supposed to call my grandmother to wish her a happy birthday.', 'The team was able to win the game despite the tough competition.', 'Were they allowed to study together for the test?', 'We are allowed to bring our own snacks to the movie theater.', 'The guests are supposed to arrive at the party by 8 pm.', 'I think she will be able to finish the marathon if s

In [95]:
classifier = RuleDetector(bert_encoder).to(device)
optimizer = torch.optim.AdamW(classifier.parameters(), 1e-4)

In [96]:
def train(model, train_dataloader, val_dataloader, num_epochs=3, criterion = torch.nn.BCELoss()):
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        
        for batch in tqdm(train_dataloader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            values, _ = model(input_ids, attention_mask=attention_mask)
            loss = criterion(values, labels.float())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_dataloader)
        print(f'Training loss: {avg_train_loss}')

        # Validation phase
        model.eval() 
        total_correct = 0
        total_examples = 0
        
        with torch.no_grad():  # No gradients needed for validation
            for batch in val_dataloader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
    
                outputs, _ = model(input_ids, attention_mask)
                predictions = outputs > 0.5                
                total_correct += (predictions.flatten() == labels).sum().item()
                total_examples += labels.size(0)

        accuracy = total_correct / total_examples
        print(f'Accuracy: {accuracy}')
        
train(classifier, train_dataloader, val_dataloader)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:01<00:00, 35.13it/s]


Training loss: 0.4804320377111435
Accuracy: 0.465


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:01<00:00, 36.09it/s]


Training loss: 0.35990911960601807
Accuracy: 0.99


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:01<00:00, 37.76it/s]


Training loss: 0.35357400119304655
Accuracy: 0.995


In [99]:
input_text = 'I want to call my grandmother to wish her a happy birthday'
encoded_input = bert_tokenizer(input_text, return_tensors='pt', max_length=64, padding='max_length').to(device)
with torch.no_grad():
    values, indices = classifier(encoded_input['input_ids'], encoded_input['attention_mask'], diagnose=True)
    print(f'Score: {values.item()}')
tokens = bert_tokenizer.convert_ids_to_tokens(encoded_input['input_ids'].squeeze().tolist())
print(f'Maximum token: {tokens[indices]}')

torch.Size([1, 64, 9984])
torch.Size([1, 64, 32])
torch.Size([1, 64, 1])
Score: 0.4964766800403595
Maximum token: [PAD]


In [62]:
list(zip(encoded_input['input_ids'][0].cpu().tolist(), encoded_input['attention_mask'][0].cpu().tolist()))

[(101, 1),
 (1045, 1),
 (2089, 1),
 (2036, 1),
 (2272, 1),
 (4826, 1),
 (1012, 1),
 (102, 1),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0),
 (0, 0)]