# Main Training Notebook for Our PCFG

In [1]:
# Import dependencies
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

import pcfg_models as pcfg_models
import pcfg_loader as loader
import source as source


scaler = GradScaler()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")

Device: cuda


  scaler = GradScaler()


In [2]:
# Configure and create the data yielder that the training loop pulls samples from
data_folder = 'training'
loader_params = dict(
    batch_size=1,
    shuffle=True,
    only_inputs=True,
    print_steps=False,
    moves_per_step=1,
    max_steps=1,
    p_use_base=0.1,
)
data_yielder = loader.get_pcfg_datayielder(data_folder, loader_params)

In [3]:
# Build the encoder classifier and place it on the selected device.
# The yielder's label ops are handed to the model as its DSL mapping.
dsl_list = data_yielder.label_ops
pcfg_encoder = pcfg_models.PCFG_Encoder(
    n_embd=64,
    n_head=4,
    n_layer=6,
    ff_hd=128,
    dropout=0,  # This may cause issues if you make this non-zero
    block_size=35,
    dsl_mapping=dsl_list,
    embedding_model_fp='vit_11-21-24_400k_v1.pth',
    freeze_emb_model=True,
    device=device,
).to(device)

# Use this to load in a pre-trained model instead
# filename = ?
# pcfg_encoder = pcfg_models.PCFG_Encoder.load_model(f"trained_pcfg_models/{filename}", print_statements=True, device=device)

  checkpoint = torch.load(path, map_location=torch.device(device))


Vision Transformer instantiated with 398,144 parameters using Sinusoidal encodings.
PCFG encoder instantiated with 202,880 parameters.


In [4]:
def embed_input(input, pcfg_encoder, use_grads, device):
    """
    Convert a token sequence into the embedding matrix consumed by the encoder.

    Args:
        input: Sequence whose items are the strings '<PAD>' / '<SEP>' or
            source.ARC_Object instances.
        pcfg_encoder: Encoder exposing model_params['n_embd'],
            get_special_tokens(device) and embedding_model.
        use_grads: Forwarded to ARC_Object.set_embedding.
        device: Device on which the embedding matrix is allocated.

    Returns:
        (x, first_pad): x has len(input) + 2 rows, the first two being the
        cls_dsl and cls_obj special tokens. first_pad is the position of the
        first '<PAD>' within the first 16 items of input, or 16 if none is found.

    Raises:
        NameError: If an item is neither '<PAD>', '<SEP>', nor an ARC_Object.
    """
    embeddings = torch.zeros((len(input), pcfg_encoder.model_params['n_embd']), device=device)
    special = pcfg_encoder.get_special_tokens(device)   # rows: cls_dsl, cls_obj, pad, sep

    # Only the leading 16 slots are searched for padding
    # (presumably the input-object window -- TODO confirm against the loader).
    first_pad = next(
        (pos for pos, token in enumerate(input[:16]) if token == '<PAD>'),
        16,
    )

    for row, token in enumerate(input):
        if token == "<PAD>":
            embeddings[row] = special[2]
        elif token == "<SEP>":
            embeddings[row] = special[3]
        elif isinstance(token, source.ARC_Object):
            # Let the object compute/cache its own embedding, then copy it into place.
            token.set_embedding(pcfg_encoder.embedding_model, use_grads=use_grads)
            embeddings[row] = token.embedding.to(device)
        else:
            raise NameError("Input contains object that is not '<PAD>', '<SEP>', or an ARC_Object.")

    # Prepend the two classification tokens (cls_dsl, cls_obj).
    return torch.cat((special[:2], embeddings), dim=0), first_pad

In [5]:
def _multi_hot_target(logits, indices):
    """Build a float multi-hot target shaped like `logits` with ones at `indices`."""
    target = torch.zeros_like(logits, dtype=torch.float)
    idx = torch.as_tensor(indices, device=logits.device)
    # An empty label set (e.g. torch.tensor([]), which defaults to float32) cannot be
    # used as an index tensor; it simply means "no positive classes".
    if idx.numel() > 0:
        target[idx] = 1.0
    return target


def compute_loss(dsl_cls, obj_att, dsl_label, obj_label):
    """
    Compute the multi-label losses for the DSL head and the object-attention head.

    Args:
        dsl_cls: Tensor of size (len(dsl)) - predicted logits for DSL classes.
        obj_att: Tensor of size (len(obj_att)) - predicted logits for object attention.
        dsl_label: Tensor (or list) of integer indices into dsl_cls marking the
            ground-truth DSL classes. May be empty.
        obj_label: Tensor (or list) of integer indices into obj_att marking the
            ground-truth objects. May be empty.

    Returns:
        Tuple (dsl_loss, obj_loss) of scalar tensors: the mean binary cross-entropy
        (with logits) of each head against its multi-hot target. The two losses are
        returned separately; combining them is left to the caller.
    """
    # Multi-hot encode the labels for the multi-label setup
    dsl_target = _multi_hot_target(dsl_cls, dsl_label)
    obj_target = _multi_hot_target(obj_att, obj_label)

    # Binary Cross Entropy with logits for multi-label classification
    dsl_loss = F.binary_cross_entropy_with_logits(dsl_cls, dsl_target)
    obj_loss = F.binary_cross_entropy_with_logits(obj_att, obj_target)
    return dsl_loss, obj_loss

In [6]:
# Training loop: pull one sample at a time from the data yielder and update the encoder.
train_params = {
    "epochs": 1,
    "lr": 1e-4,
    "print_frequency": 100,   # Report averaged losses every this many steps
    "max_steps": 200,         # Per-epoch cap: the inner loop stops once step exceeds this
}

optim = torch.optim.AdamW(pcfg_encoder.parameters(), lr=train_params['lr'])
avg_losses = []                       # One (total, dsl, obj) tuple per reporting window
seen_labels = [0] * len(dsl_list)     # Count of samples per DSL index, keyed on label[0]

for epoch in range(train_params['epochs']):
    step, dsl_loss_sum, obj_loss_sum, total_loss_sum = 0, 0, 0, 0

    while True:
        try:
            key, input, label, obj_idx = next(data_yielder)
        except StopIteration:
            # Yielder is out of data: end this epoch
            break
        except loader.SampleExhausted:
            # Skip this draw and ask the yielder for another sample
            continue

        seen_labels[label[0]] += 1

        # Only the forward pass and the loss computation run under autocast.
        with torch.amp.autocast('cuda'):
            x, first_pad = embed_input(input, pcfg_encoder, use_grads=not pcfg_encoder.freeze_emb_model, device=device)
            dsl_cls, obj_att = pcfg_encoder(x)
            obj_att = obj_att[2:first_pad + 2]  # Since we have 2 special cls tokens at the start

            # Compute individual losses
            dsl_loss, obj_loss = compute_loss(
                dsl_cls,
                obj_att,
                torch.tensor(label, device=device),
                torch.tensor(obj_idx, device=device)
            )

            # The object loss is tracked for reporting but is not optimised for now.
            # total_loss = dsl_loss + obj_loss
            total_loss = dsl_loss

        # Backward pass and optimizer step happen outside the autocast context:
        # backward passes under autocast are not recommended by the PyTorch AMP docs.
        optim.zero_grad(set_to_none=True)
        scaler.scale(total_loss).backward()
        scaler.step(optim)
        scaler.update()

        # Accumulate losses
        total_loss_sum += total_loss.item()
        dsl_loss_sum += dsl_loss.item()
        obj_loss_sum += obj_loss.item()
        step += 1

        if step % train_params['print_frequency'] == 0:
            avg_total_loss = total_loss_sum / train_params['print_frequency']
            avg_dsl_loss = dsl_loss_sum / train_params['print_frequency']
            avg_obj_loss = obj_loss_sum / train_params['print_frequency']

            avg_losses.append((avg_total_loss, avg_dsl_loss, avg_obj_loss))

            print(f"[{(epoch+1):>2}/{train_params['epochs']:>2}] - {step:>4}: Total Loss = {avg_total_loss:.4f}, DSL Loss = {avg_dsl_loss:.4f}, Obj Loss = {avg_obj_loss:.4f}")

            # Start a fresh reporting window
            total_loss_sum = 0
            dsl_loss_sum = 0
            obj_loss_sum = 0

        if step > train_params['max_steps']:
            break


[ 1/ 1] -  100: Total Loss = 0.5184, DSL Loss = 0.5184, Obj Loss = 0.6971
[ 1/ 1] -  200: Total Loss = 0.4377, DSL Loss = 0.4377, Obj Loss = 0.7080


In [7]:
# Persist the trained encoder. save_model is called with no arguments; the recorded
# output shows it writing a timestamped .pth file under trained_pcfg_models.
pcfg_encoder.save_model()

Model and embedding model saved to trained_pcfg_models\pcfg_encoder_20241203_160621.pth
