# Reinforcement Learning

In [423]:
import torch
import torch.nn as nn

def training_loop(model = EntityMatrixPredictor(), device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), epochs = 3, pos_weight= 20, model_path=model_path, verbose = False, lambda_penalty = 0.1):
    """Train `model` with a weighted BCE loss plus a REINFORCE-style term.

    For every batch the model's logits are scored against the target matrix
    with `BCEWithLogitsLoss`. In addition, one binary matrix per example is
    sampled from the sigmoid of the logits, scored by `compute_loop_reward`,
    and the batch-normalised rewards weight the log-probabilities of those
    samples. The weighted sum (negated) is added to the BCE loss, scaled by
    `lambda_penalty`.

    Args:
        model: Module called as `model(input_ids=..., attention_mask=...,
            word_ids=...)` and returning a 3-D tensor of logits.
            NOTE: the default instance is created once, when this `def` is
            executed, so repeated calls without `model` keep training the
            same object.
        device: Device the model and tensors are moved to.
        epochs: Number of passes over the loader returned by
            `get_train_loader()` (a fresh loader is requested every epoch).
        pos_weight: Weight of the positive class in the BCE loss.
        model_path: Accepted for interface compatibility; not used inside
            this function.
        verbose: When true, print every sampled reward, the RL term of each
            batch and the mean loss of each epoch.
        lambda_penalty: Multiplier applied to the RL term before it is added
            to the BCE loss.

    Returns:
        The same `model` object, after training.
    """
    model.to(device)
    loss_bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        # Count batches as they are consumed instead of calling len() on the
        # loader: works for loaders without __len__ and for empty loaders.
        num_batches = 0
        train_loader = get_train_loader()

        for batch in train_loader:
            tokens = batch[0]
            target_matrix = batch[1].to(device)  # targets must live on the model's device

            input_ids = tokens["input_ids"].to(device)
            attention_mask = tokens["attention_mask"].to(device)

            # Extract `word_ids` only once for the batch
            word_ids = [tokens.word_ids(batch_index=i) for i in range(len(input_ids))]

            optimizer.zero_grad()

            # Forward pass (includes word_ids)
            predicted_matrix = model(input_ids=input_ids, attention_mask=attention_mask, word_ids=word_ids)

            # Trim the target so its last two dimensions match the prediction.
            batch_size, max_words, _ = predicted_matrix.shape
            target_matrix = target_matrix[:, :max_words, :max_words]

            rewards = []
            log_probs = []
            probs = torch.sigmoid(predicted_matrix)

            for i in range(probs.size(0)):  # Loop over batch
                # Keep the sample on the same device as `probs`: it is
                # multiplied with `probs[i]` below, and mixing a CPU tensor
                # with a CUDA tensor raises a RuntimeError.
                sampled_matrix = torch.bernoulli(probs[i]).detach()
                # Only the reward function receives a CPU copy.
                reward = compute_loop_reward(sampled_matrix.cpu())
                if verbose:
                    print(f"Reward: {reward}")

                rewards.append(float(reward))

                # Mean Bernoulli log-likelihood of the sample; the 1e-6 keeps
                # log() finite when a probability saturates at 0 or 1.
                log_p = (sampled_matrix * torch.log(probs[i] + 1e-6) +
                        (1 - sampled_matrix) * torch.log(1 - probs[i] + 1e-6)).mean()
                log_probs.append(log_p)

            # Explicit floating dtype so mean()/std() and the dot product with
            # the log-probabilities work even if the rewards are integers.
            rewards = torch.tensor(rewards, dtype=probs.dtype, device=probs.device)
            centered = rewards - rewards.mean()
            if rewards.numel() > 1:
                rewards = centered / (rewards.std() + 1e-6)
            else:
                # The unbiased std of a single value is NaN, which would turn
                # the loss (and then the weights) into NaN. With one sample
                # the centred reward is 0, so the RL term simply vanishes.
                rewards = centered
            rl_penalty = -(torch.stack(log_probs) @ rewards)

            if verbose:
                print(f"RL Penalty: {rl_penalty}")

            # Supervised loss plus the scaled RL term
            loss = loss_bce(predicted_matrix, target_matrix) + lambda_penalty * rl_penalty
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Mean loss per batch; NaN (not a ZeroDivisionError) if no batch was seen.
        epoch_loss = total_loss / num_batches if num_batches else float("nan")
        if verbose: print(f"Epoch {epoch+1}, Loss: {epoch_loss}")

    return model

In [424]:
# Run training with the per-sample reward / penalty logging switched on.
model = training_loop(verbose=True)

# Persist only the parameters (the state dict) at `model_path`.
torch.save(obj=model.state_dict(), f=model_path)
print(f"model saved at {model_path}")

Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
RL Penalty: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.017241379246115685
Reward: 0.008620689623057842
Reward: 0.0


KeyboardInterrupt: 