# Exp007: Train grammar rule classifier
This experiment elaborates different ways to train a binary classifier to detect the usage of an EGP rule in a sentence.

In [None]:
import pandas as pd
from dotenv import load_dotenv
load_dotenv()
import os
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader, random_split
import random
from tqdm import tqdm

In [70]:
class RuleDetector(torch.nn.Module):
    def __init__(self, bert_encoder, hidden_dim=32, dropout_rate=0.25, train_bert=False):
        super().__init__()
        self.bert = bert_encoder
        for param in self.bert.parameters():
            param.requires_grad = train_bert
        input_dim = self.bert.config.hidden_size*(self.bert.config.num_hidden_layers+1)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.hidden = torch.nn.Linear(input_dim, hidden_dim)
        self.relu = torch.nn.ReLU()
        self.output = torch.nn.Linear(hidden_dim, 1)
        self.sigmoid = torch.nn.Sigmoid()
    
    def forward(self, input_ids, attention_mask, diagnose=False):
        with torch.no_grad():
            outputs = self.bert(input_ids, attention_mask)
            x = torch.cat(outputs.hidden_states, dim=-1)
        if diagnose:
            print(x.shape)
        x = self.dropout(x)
        x = self.hidden(x)
        if diagnose:
            print(x.shape)
        x = self.relu(x)
        x = self.output(x)
        x = self.sigmoid(x)
        if diagnose:
            print(x)
        x = x * attention_mask.unsqueeze(-1)
        if diagnose:
            print(x)
        
        max_values, max_indices = torch.max(x, 1)
        return max_values.flatten(), max_indices.flatten()

In [3]:
class SentenceDataset(Dataset):
    def __init__(self, sentences, labels, tokenizer, max_len):
        self.sentences = sentences
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        encoding = self.tokenizer(self.sentences[idx], return_tensors='pt', max_length=self.max_len, padding='max_length')
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }
    
def get_dataset(row, tokenizer, max_len, df, random_negatives=True, ratio = 0.5, max_positive_examples=500):
    # assemble dataset for one construction
    # 50% positive examples
    unique_examples = list(set(row['augmented_examples']))
    sentences = unique_examples[:max_positive_examples]
    labels = [1] * len(sentences)

    num_augs = int(len(sentences) * (1-ratio)) if random_negatives else len(sentences)
    # augmented negative examples
    aug_neg_examples = list(set(row['augmented_negative_examples']).difference(set(row['augmented_examples'])))
    random.shuffle(aug_neg_examples)
    unique_negatives = aug_neg_examples[:num_augs]
    sentences += unique_negatives
    labels += [0] * len(unique_negatives)
    
    if random_negatives:
        num_rands = max_positive_examples - len(unique_negatives) # fill to an even number
        # rest: random negative examples (positive from other constructions)
        neg_examples = [example for sublist in df.loc[df['#'] != row['#'], 'augmented_examples'].to_list() for example in sublist]
        random.shuffle(neg_examples)
        sentences += neg_examples[:num_rands]
        labels += [0] * len(neg_examples[:num_rands])
    assert len(sentences) == 2 * max_positive_examples
    assert sum(labels) == max_positive_examples
    return SentenceDataset(sentences, labels, tokenizer, max_len)

def get_loaders(dataset, batch_size=16):
    total_size = len(dataset)
    train_size = int(0.8 * total_size)
    val_size = total_size - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    return train_dataloader, val_dataloader

In [4]:
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir=os.getenv('CACHE_DIR'))
bert_encoder = BertModel.from_pretrained('bert-base-uncased', cache_dir=os.getenv('CACHE_DIR'), output_hidden_states=True)
classifier = RuleDetector(bert_encoder).to(device)

In [5]:
total_params = sum(p.numel() for p in classifier.parameters() if p.requires_grad)
print(f"Total parameters: {total_params}")

Total parameters: 319553


Test model

In [8]:
input_text = "Switzerland is a beautiful country."
encoded_input = bert_tokenizer(input_text, return_tensors='pt').to(device)
max_values, max_indices = classifier(encoded_input['input_ids'], encoded_input['attention_mask'])
max_values

tensor([0.5531], device='cuda:0', grad_fn=<ViewBackward0>)

Train the rule detector for one random rule

In [9]:
egp_examples = pd.read_json("../data/egp_examples.json")

In [328]:
rule = egp_examples.sample(1).iloc[0]
print(rule['Can-do statement'])
print(rule['Example'])
print(random.sample(rule['augmented_examples'], 10))
print(random.sample(rule['augmented_negative_examples'], 10))
classifier = RuleDetector(bert_encoder).to(device)

optimizer = torch.optim.AdamW(classifier.parameters(), 1e-4)
dataset = get_dataset(rule, bert_tokenizer, 64, egp_examples) 
train_dataloader, val_dataloader = get_loaders(dataset)

Can use 'have (got) to' to talk about obligations.
The concert starts at midnight but we have to go before then because we have got to buy our tickets. 

You have to bring your swimming costume.
['They have to complete the online training module by Thursday.', 'She has to take her medicine at the same time every day.', 'We have to leave for the airport at 6 am.', 'We have got to send out the invitations for the party soon.', 'She has got to take her medicine after dinner.', 'We have to wake up early for our flight.', 'We’ve got to buy groceries after work.', 'I have to jog in the mornings to stay fit.', "We've got to leave early to avoid the traffic.", 'He has to submit his report by the end of the day.']
['She returns the library books before they become overdue.', 'We plan our vacation soon.', 'She practices the piano before her recital.', 'Do you take the dog for a walk every morning?', 'She buys groceries for dinner tonight.', 'I finish my homework before I can go out and play.', '

In [248]:
classifier = RuleDetector(bert_encoder).to(device)
optimizer = torch.optim.AdamW(classifier.parameters(), 1e-4)

In [325]:
def train(model, train_dataloader, val_dataloader, num_epochs=3, criterion = torch.nn.BCELoss()):
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        
        for batch in tqdm(train_dataloader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            values, _ = model(input_ids, attention_mask=attention_mask, diagnose=False)
            loss = criterion(values, labels.float())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_dataloader)
        print(f'Training loss: {avg_train_loss}')

        # Validation phase
        model.eval() 
        total_correct = 0
        total_examples = 0
        
        with torch.no_grad():  # No gradients needed for validation
            for batch in val_dataloader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
    
                outputs, _ = model(input_ids, attention_mask)
                predictions = outputs > 0.5                
                total_correct += (predictions.flatten() == labels).sum().item()
                total_examples += labels.size(0)

        accuracy = total_correct / total_examples
        print(f'Accuracy: {accuracy}')
        
train(classifier, train_dataloader, val_dataloader)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:01<00:00, 36.50it/s]


Training loss: 0.400087109208107
Accuracy: 0.97


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:01<00:00, 38.13it/s]


Training loss: 0.16590621560811997
Accuracy: 0.98


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:01<00:00, 37.56it/s]


Training loss: 0.10522952415049076
Accuracy: 0.975


In [327]:
input_text = "The class was on Monday. It started at 6:00 pm and finished at 7:00 pm."
encoded_input = bert_tokenizer(input_text, return_tensors='pt', max_length=64, padding='max_length').to(device)
with torch.no_grad():
    values, indices = classifier(encoded_input['input_ids'], encoded_input['attention_mask'], diagnose=False)
    print(f'Score: {values.item()}')
tokens = bert_tokenizer.convert_ids_to_tokens(encoded_input['input_ids'].squeeze().tolist())
print(f'Maximum token: {tokens[indices]}')

Score: 0.9988191723823547
Maximum token: 6
